diff --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
--- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
+++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td
@@ -51,6 +51,21 @@
   string flangClauseBaseClass = "";
 }
 
+// Information about values accepted by enum-like clauses.
+class ClauseVal<string n, int v, bit uv> {
+  // Name of the clause value.
+  string name = n;
+
+  // Integer value of the clause value (used as the C++ enumerator value).
+  int value = v;
+
+  // Can user specify this value?
+  bit isUserValue = uv;
+
+  // Set clause value used by default when unknown.
+  bit isDefault = 0;
+}
+
 // Information about a specific clause.
 class Clause<string c> {
   // Name of the clause.
@@ -75,11 +90,17 @@
   // If set to 1, value is optional. Not optional by default.
   bit isValueOptional = 0;
 
+  // Name of enum when there is a list of allowed clause values.
+  string enumClauseValue = "";
+
+  // List of allowed clause values.
+  list<ClauseVal> allowedClauseValues = [];
+
   // Is clause implicit? If clause is set as implicit, the default kind will
   // be return in getClauseKind instead of their own kind.
   bit isImplicit = 0;
 
-  // Set directive used by default when unknown. Function returning the kind
+  // Set clause used by default when unknown. Function returning the kind
   // of enumeration will use this clause as the default.
   bit isDefault = 0;
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -99,9 +99,22 @@
   let clangClass = "OMPCopyprivateClause";
   let flangClassValue = "OmpObjectList";
 }
+def OMP_PROC_BIND_master : ClauseVal<"master",2,1> {}
+def OMP_PROC_BIND_close : ClauseVal<"close",3,1> {}
+def OMP_PROC_BIND_spread : ClauseVal<"spread",4,1> {}
+def OMP_PROC_BIND_default : ClauseVal<"default",5,0> {}
+def OMP_PROC_BIND_unknown : ClauseVal<"unknown",6,0> { let isDefault = 1; }
 def OMPC_ProcBind : Clause<"proc_bind"> {
   let clangClass = "OMPProcBindClause";
   let flangClass = "OmpProcBindClause";
+  let enumClauseValue = "ProcBindKind";
+  let allowedClauseValues = [
+    OMP_PROC_BIND_master,
+    OMP_PROC_BIND_close,
+    OMP_PROC_BIND_spread,
+    OMP_PROC_BIND_default,
+    OMP_PROC_BIND_unknown
+  ];
 }
 def OMPC_Schedule : Clause<"schedule"> {
   let clangClass = "OMPScheduleClause";
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -68,16 +68,6 @@
   constexpr auto Enum = omp::DefaultKind::Enum;
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
 
-/// IDs for the different proc bind kinds.
-enum class ProcBindKind {
-#define OMP_PROC_BIND_KIND(Enum, Str, Value) Enum = Value,
-#include "llvm/Frontend/OpenMP/OMPKinds.def"
-};
-
-#define OMP_PROC_BIND_KIND(Enum, ...)                                          \
-  constexpr auto Enum = omp::ProcBindKind::Enum;
-#include "llvm/Frontend/OpenMP/OMPKinds.def"
-
 /// IDs for all omp runtime library ident_t flag encodings (see
 /// their defintion in openmp/runtime/src/kmp.h).
enum class IdentFlag { diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h @@ -0,0 +1,188 @@ +#ifndef LLVM_TABLEGEN_DIRECTIVEEMITTER_H +#define LLVM_TABLEGEN_DIRECTIVEEMITTER_H + +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Record.h" + +namespace llvm { + +// Wrapper class that contains DirectiveLanguage's information defined in +// DirectiveBase.td and provides helper methods for accessing it. +class DirectiveLanguage { +public: + explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {} + + StringRef getName() const { return Def->getValueAsString("name"); } + + StringRef getCppNamespace() const { + return Def->getValueAsString("cppNamespace"); + } + + StringRef getDirectivePrefix() const { + return Def->getValueAsString("directivePrefix"); + } + + StringRef getClausePrefix() const { + return Def->getValueAsString("clausePrefix"); + } + + StringRef getIncludeHeader() const { + return Def->getValueAsString("includeHeader"); + } + + StringRef getClauseEnumSetClass() const { + return Def->getValueAsString("clauseEnumSetClass"); + } + + StringRef getFlangClauseBaseClass() const { + return Def->getValueAsString("flangClauseBaseClass"); + } + + bool hasMakeEnumAvailableInNamespace() const { + return Def->getValueAsBit("makeEnumAvailableInNamespace"); + } + + bool hasEnableBitmaskEnumInNamespace() const { + return Def->getValueAsBit("enableBitmaskEnumInNamespace"); + } + +private: + const llvm::Record *Def; +}; + +// Base record class used for Directive and Clause class defined in +// DirectiveBase.td. 
+class BaseRecord { +public: + explicit BaseRecord(const llvm::Record *Def) : Def(Def) {} + + StringRef getName() const { return Def->getValueAsString("name"); } + + StringRef getAlternativeName() const { + return Def->getValueAsString("alternativeName"); + } + + // Returns the name of the directive formatted for output. Whitespace are + // replaced with underscores. + std::string getFormattedName() { + StringRef Name = Def->getValueAsString("name"); + std::string N = Name.str(); + std::replace(N.begin(), N.end(), ' ', '_'); + return N; + } + + bool isDefault() const { return Def->getValueAsBit("isDefault"); } + +protected: + const llvm::Record *Def; +}; + +// Wrapper class that contains a Directive's information defined in +// DirectiveBase.td and provides helper methods for accessing it. +class Directive : public BaseRecord { +public: + explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {} + + std::vector getAllowedClauses() const { + return Def->getValueAsListOfDefs("allowedClauses"); + } + + std::vector getAllowedOnceClauses() const { + return Def->getValueAsListOfDefs("allowedOnceClauses"); + } + + std::vector getAllowedExclusiveClauses() const { + return Def->getValueAsListOfDefs("allowedExclusiveClauses"); + } + + std::vector getRequiredClauses() const { + return Def->getValueAsListOfDefs("requiredClauses"); + } +}; + +// Wrapper class that contains Clause's information defined in DirectiveBase.td +// and provides helper methods for accessing it. +class Clause : public BaseRecord { +public: + explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {} + + // Optional field. + StringRef getClangClass() const { + return Def->getValueAsString("clangClass"); + } + + // Optional field. + StringRef getFlangClass() const { + return Def->getValueAsString("flangClass"); + } + + // Optional field. + StringRef getFlangClassValue() const { + return Def->getValueAsString("flangClassValue"); + } + + // Get the formatted name for Flang parser class. 
The generic formatted class + // name is constructed from the name were the first letter of each word is + // captitalized and the underscores are removed. + // ex: async -> Async + // num_threads -> NumThreads + std::string getFormattedParserClassName() { + StringRef Name = Def->getValueAsString("name"); + std::string N = Name.str(); + bool Cap = true; + std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) { + if (Cap == true) { + C = llvm::toUpper(C); + Cap = false; + } else if (C == '_') { + Cap = true; + } + return C; + }); + N.erase(std::remove(N.begin(), N.end(), '_'), N.end()); + return N; + } + + // Optional field. + StringRef getEnumName() const { + return Def->getValueAsString("enumClauseValue"); + } + + std::vector getClauseVals() const { + return Def->getValueAsListOfDefs("allowedClauseValues"); + } + + bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); } + + bool isImplict() const { return Def->getValueAsBit("isImplicit"); } +}; + +// Wrapper class that contains VersionedClause's information defined in +// DirectiveBase.td and provides helper methods for accessing it. +class VersionedClause { +public: + explicit VersionedClause(const llvm::Record *Def) : Def(Def) {} + + // Return the specific clause record wrapped in the Clause class. 
+  Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
+
+  int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
+
+  int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
+
+private:
+  const llvm::Record *Def;
+};
+
+class ClauseVal : public BaseRecord {
+public:
+  explicit ClauseVal(const llvm::Record *Def) : BaseRecord(Def) {}
+
+  int64_t getValue() const { return Def->getValueAsInt("value"); }
+
+  bool isUserVisible() const { return Def->getValueAsBit("isUserValue"); }
+};
+
+} // namespace llvm
+
+#endif // LLVM_TABLEGEN_DIRECTIVEEMITTER_H
diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td
--- a/llvm/test/TableGen/directive1.td
+++ b/llvm/test/TableGen/directive1.td
@@ -15,9 +15,20 @@
   let flangClauseBaseClass = "TdlClause";
 }
 
+def TDLCV_vala : ClauseVal<"vala",1,1> {}
+def TDLCV_valb : ClauseVal<"valb",2,1> {}
+def TDLCV_valc : ClauseVal<"valc",3,0> { let isDefault = 1; }
+
 def TDLC_ClauseA : Clause<"clausea"> {
   let flangClass = "TdlClauseA";
+  let enumClauseValue = "AKind";
+  let allowedClauseValues = [
+    TDLCV_vala,
+    TDLCV_valb,
+    TDLCV_valc
+  ];
 }
+
 def TDLC_ClauseB : Clause<"clauseb"> {
   let flangClassValue = "IntExpr";
   let isValueOptional = 1;
@@ -61,6 +72,16 @@
 // CHECK-NEXT:  constexpr auto TDLC_clausea = llvm::tdl::Clause::TDLC_clausea;
 // CHECK-NEXT:  constexpr auto TDLC_clauseb = llvm::tdl::Clause::TDLC_clauseb;
 // CHECK-EMPTY:
+// CHECK-NEXT:  enum class AKind {
+// CHECK-NEXT:    TDLCV_vala=1,
+// CHECK-NEXT:    TDLCV_valb=2,
+// CHECK-NEXT:    TDLCV_valc=3,
+// CHECK-NEXT:  };
+// CHECK-EMPTY:
+// CHECK-NEXT:  constexpr auto TDLCV_vala = llvm::tdl::AKind::TDLCV_vala;
+// CHECK-NEXT:  constexpr auto TDLCV_valb = llvm::tdl::AKind::TDLCV_valb;
+// CHECK-NEXT:  constexpr auto TDLCV_valc = llvm::tdl::AKind::TDLCV_valc;
+// CHECK-EMPTY:
 // CHECK-NEXT:  // Enumeration helper functions
 // CHECK-NEXT:  Directive getTdlDirectiveKind(llvm::StringRef Str);
 // CHECK-EMPTY:
@@ -73,6 +94,8 @@
 // CHECK-NEXT:  /// Return true if \p C is a valid clause for \p D in version \p Version.
 // CHECK-NEXT:  bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
 // CHECK-EMPTY:
+// CHECK-NEXT:  AKind getAKind(StringRef);
+// CHECK-EMPTY:
 // CHECK-NEXT:  } // namespace tdl
 // CHECK-NEXT:  } // namespace llvm
 // CHECK-NEXT:  #endif // LLVM_Tdl_INC
@@ -116,6 +139,14 @@
 // IMPL-NEXT:    llvm_unreachable("Invalid Tdl Clause kind");
 // IMPL-NEXT:  }
 // IMPL-EMPTY:
+// IMPL-NEXT:  AKind llvm::tdl::getAKind(llvm::StringRef Str) {
+// IMPL-NEXT:    return llvm::StringSwitch<AKind>(Str)
+// IMPL-NEXT:      .Case("vala",TDLCV_vala)
+// IMPL-NEXT:      .Case("valb",TDLCV_valb)
+// IMPL-NEXT:      .Case("valc",TDLCV_valc)
+// IMPL-NEXT:      .Default(TDLCV_valc);
+// IMPL-NEXT:  }
+// IMPL-EMPTY:
 // IMPL-NEXT:  bool llvm::tdl::isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) {
 // IMPL-NEXT:    assert(unsigned(D) <= llvm::tdl::Directive_enumSize);
 // IMPL-NEXT:    assert(unsigned(C) <= llvm::tdl::Clause_enumSize);
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -11,15 +11,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/TableGen/DirectiveEmitter.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
-
 using namespace llvm;
 
 namespace {
@@ -41,165 +40,6 @@
 
 namespace llvm {
 
-// Wrapper class that contains DirectiveLanguage's information defined in
-// DirectiveBase.td and provides helper methods for accessing it.
-class DirectiveLanguage {
-public:
-  explicit DirectiveLanguage(const llvm::Record *Def) : Def(Def) {}
-
-  StringRef getName() const { return Def->getValueAsString("name"); }
-
-  StringRef getCppNamespace() const {
-    return Def->getValueAsString("cppNamespace");
-  }
-
-  StringRef getDirectivePrefix() const {
-    return Def->getValueAsString("directivePrefix");
-  }
-
-  StringRef getClausePrefix() const {
-    return Def->getValueAsString("clausePrefix");
-  }
-
-  StringRef getIncludeHeader() const {
-    return Def->getValueAsString("includeHeader");
-  }
-
-  StringRef getClauseEnumSetClass() const {
-    return Def->getValueAsString("clauseEnumSetClass");
-  }
-
-  StringRef getFlangClauseBaseClass() const {
-    return Def->getValueAsString("flangClauseBaseClass");
-  }
-
-  bool hasMakeEnumAvailableInNamespace() const {
-    return Def->getValueAsBit("makeEnumAvailableInNamespace");
-  }
-
-  bool hasEnableBitmaskEnumInNamespace() const {
-    return Def->getValueAsBit("enableBitmaskEnumInNamespace");
-  }
-
-private:
-  const llvm::Record *Def;
-};
-
-// Base record class used for Directive and Clause class defined in
-// DirectiveBase.td.
-class BaseRecord {
-public:
-  explicit BaseRecord(const llvm::Record *Def) : Def(Def) {}
-
-  StringRef getName() const { return Def->getValueAsString("name"); }
-
-  StringRef getAlternativeName() const {
-    return Def->getValueAsString("alternativeName");
-  }
-
-  // Returns the name of the directive formatted for output. Whitespace are
-  // replaced with underscores.
-  std::string getFormattedName() {
-    StringRef Name = Def->getValueAsString("name");
-    std::string N = Name.str();
-    std::replace(N.begin(), N.end(), ' ', '_');
-    return N;
-  }
-
-  bool isDefault() const { return Def->getValueAsBit("isDefault"); }
-
-protected:
-  const llvm::Record *Def;
-};
-
-// Wrapper class that contains a Directive's information defined in
-// DirectiveBase.td and provides helper methods for accessing it.
-class Directive : public BaseRecord {
-public:
-  explicit Directive(const llvm::Record *Def) : BaseRecord(Def) {}
-
-  std::vector<Record *> getAllowedClauses() const {
-    return Def->getValueAsListOfDefs("allowedClauses");
-  }
-
-  std::vector<Record *> getAllowedOnceClauses() const {
-    return Def->getValueAsListOfDefs("allowedOnceClauses");
-  }
-
-  std::vector<Record *> getAllowedExclusiveClauses() const {
-    return Def->getValueAsListOfDefs("allowedExclusiveClauses");
-  }
-
-  std::vector<Record *> getRequiredClauses() const {
-    return Def->getValueAsListOfDefs("requiredClauses");
-  }
-};
-
-// Wrapper class that contains Clause's information defined in DirectiveBase.td
-// and provides helper methods for accessing it.
-class Clause : public BaseRecord {
-public:
-  explicit Clause(const llvm::Record *Def) : BaseRecord(Def) {}
-
-  // Optional field.
-  StringRef getClangClass() const {
-    return Def->getValueAsString("clangClass");
-  }
-
-  // Optional field.
-  StringRef getFlangClass() const {
-    return Def->getValueAsString("flangClass");
-  }
-
-  // Optional field.
-  StringRef getFlangClassValue() const {
-    return Def->getValueAsString("flangClassValue");
-  }
-
-  // Get the formatted name for Flang parser class. The generic formatted class
-  // name is constructed from the name were the first letter of each word is
-  // captitalized and the underscores are removed.
-  // ex: async -> Async
-  //     num_threads -> NumThreads
-  std::string getFormattedParserClassName() {
-    StringRef Name = Def->getValueAsString("name");
-    std::string N = Name.str();
-    bool Cap = true;
-    std::transform(N.begin(), N.end(), N.begin(), [&Cap](unsigned char C) {
-      if (Cap == true) {
-        C = llvm::toUpper(C);
-        Cap = false;
-      } else if (C == '_') {
-        Cap = true;
-      }
-      return C;
-    });
-    N.erase(std::remove(N.begin(), N.end(), '_'), N.end());
-    return N;
-  }
-
-  bool isValueOptional() const { return Def->getValueAsBit("isValueOptional"); }
-
-  bool isImplict() const { return Def->getValueAsBit("isImplicit"); }
-};
-
-// Wrapper class that contains VersionedClause's information defined in
-// DirectiveBase.td and provides helper methods for accessing it.
-class VersionedClause {
-public:
-  explicit VersionedClause(const llvm::Record *Def) : Def(Def) {}
-
-  // Return the specific clause record wrapped in the Clause class.
-  Clause getClause() const { return Clause{Def->getValueAsDef("clause")}; }
-
-  int64_t getMinVersion() const { return Def->getValueAsInt("minVersion"); }
-
-  int64_t getMaxVersion() const { return Def->getValueAsInt("maxVersion"); }
-
-private:
-  const llvm::Record *Def;
-};
-
 // Generate enum class
 void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS,
                        StringRef Enum, StringRef Prefix,
@@ -231,6 +71,47 @@
   }
 }
 
+// Generate enums for values that clauses can take.
+// Also generate function declarations for get<EnumName>(StringRef Str).
+void GenerateEnumClauseVal(const std::vector<Record *> &Records,
+                           raw_ostream &OS, DirectiveLanguage &DirLang,
+                           std::string &EnumHelperFuncs) {
+  for (const auto &R : Records) {
+    Clause C{R};
+    const auto &ClauseVals = C.getClauseVals();
+    if (ClauseVals.empty())
+      continue;
+
+    const auto &EnumName = C.getEnumName();
+    if (EnumName.empty()) {
+      PrintError("enumClauseValue field not set in Clause " +
+                 C.getFormattedName() + ".");
+      return;
+    }
+
+    OS << "\n";
+    OS << "enum class " << EnumName << " {\n";
+    for (const auto &CV : ClauseVals) {
+      ClauseVal CVal{CV};
+      OS << "  " << CV->getName() << "=" << CVal.getValue() << ",\n";
+    }
+    OS << "};\n";
+
+    if (DirLang.hasMakeEnumAvailableInNamespace()) {
+      OS << "\n";
+      for (const auto &CV : ClauseVals) {
+        OS << "constexpr auto " << CV->getName() << " = "
+           << "llvm::" << DirLang.getCppNamespace() << "::" << EnumName
+           << "::" << CV->getName() << ";\n";
+      }
+    }
+    // get<EnumName> is always implemented, so always declare it.
+    EnumHelperFuncs += (llvm::Twine(EnumName) + llvm::Twine(" get") +
+                        llvm::Twine(EnumName) + llvm::Twine("(StringRef);\n"))
+                           .str();
+  }
+}
+
 // Generate the declaration section for the enumeration in the directive
 // language
 void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -273,6 +154,10 @@
   const auto &Clauses = Records.getAllDerivedDefinitions("Clause");
   GenerateEnumClass(Clauses, OS, "Clause", DirLang.getClausePrefix(), DirLang);
 
+  // Emit ClauseVal enumeration
+  std::string EnumHelperFuncs;
+  GenerateEnumClauseVal(Clauses, OS, DirLang, EnumHelperFuncs);
+
   // Generic function signatures
   OS << "\n";
   OS << "// Enumeration helper functions\n";
@@ -292,6 +177,10 @@
   OS << "bool isAllowedClauseForDirective(Directive D, "
      << "Clause C, unsigned Version);\n";
   OS << "\n";
+  if (!EnumHelperFuncs.empty()) {
+    OS << EnumHelperFuncs;
+    OS << "\n";
+  }
 
   // Closing namespaces
   for (auto Ns : llvm::reverse(Namespaces))
@@ -336,7 +225,7 @@
       });
 
   if (DefaultIt == Records.end()) {
-    PrintError("A least one " + Enum + " must be defined as default.");
+    PrintError("At least one " + Enum + " must be defined as default.");
     return;
   }
 
@@ -361,6 +250,56 @@
   OS << "}\n";
 }
 
+// Generate function implementation for get<EnumName>(StringRef Str), e.g.
+// getProcBindKind. Unknown strings map to the value marked as default.
+void GenerateGetKindClauseVal(const std::vector<Record *> &Records,
+                              raw_ostream &OS, DirectiveLanguage &DirLang) {
+  StringRef Namespace = DirLang.getCppNamespace();
+  for (const auto &R : Records) {
+    Clause C{R};
+    const auto &ClauseVals = C.getClauseVals();
+    if (ClauseVals.empty())
+      continue;
+
+    auto DefaultIt =
+        std::find_if(ClauseVals.begin(), ClauseVals.end(), [](Record *CV) {
+          return CV->getValueAsBit("isDefault") == true;
+        });
+
+    if (DefaultIt == ClauseVals.end()) {
+      PrintError("At least one val in Clause " + C.getFormattedName() +
+                 " must be defined as default.");
+      return;
+    }
+    const auto DefaultName = (*DefaultIt)->getName();
+
+    const auto &EnumName = C.getEnumName();
+    if (EnumName.empty()) {
+      PrintError("enumClauseValue field not set in Clause " +
+                 C.getFormattedName() + ".");
+      return;
+    }
+
+    // The bare value names are only declared when makeEnumAvailableInNamespace
+    // is set; otherwise qualify them with the enum class name.
+    std::string Qual;
+    if (!DirLang.hasMakeEnumAvailableInNamespace())
+      Qual = (EnumName + "::").str();
+
+    OS << "\n";
+    OS << EnumName << " llvm::" << Namespace << "::get" << EnumName
+       << "(llvm::StringRef Str) {\n";
+    OS << "  return llvm::StringSwitch<" << EnumName << ">(Str)\n";
+    for (const auto &CV : ClauseVals) {
+      ClauseVal CVal{CV};
+      OS << "    .Case(\"" << CVal.getFormattedName() << "\"," << Qual
+         << CV->getName() << ")\n";
+    }
+    OS << "    .Default(" << Qual << DefaultName << ");\n";
+    OS << "}\n";
+  }
+}
+
 void GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
                                      raw_ostream &OS, StringRef DirectiveName,
                                      DirectiveLanguage &DirLang,
@@ -672,6 +611,9 @@
   // getClauseName(Clause Kind)
   GenerateGetName(Clauses, OS, "Clause", DirLang, DirLang.getClausePrefix());
 
+  // get<EnumName>(StringRef Str)
+  GenerateGetKindClauseVal(Clauses, OS, DirLang);
+
   // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
   GenerateIsAllowedClause(Directives, OS, DirLang);
 }
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenMP/OMP.td)
+mlir_tablegen(OmpCommon.td --gen-directive-decl)
+add_public_tablegen_target(omp_common_td)
+
 set(LLVM_TARGET_DEFINITIONS OpenMPOps.td)
 mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp)
 mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -15,6 +15,7 @@
 #define OPENMP_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/Dialect/OpenMP/OmpCommon.td"
 
 def OpenMP_Dialect : Dialect {
   let name = "omp";
@@ -42,18 +43,6 @@
   let cppNamespace = "::mlir::omp";
 }
 
-// Possible values for the proc_bind clause
-def ClauseProcMaster : StrEnumAttrCase<"master">;
-def ClauseProcClose : StrEnumAttrCase<"close">;
-def ClauseProcSpread : StrEnumAttrCase<"spread">;
-
-def ClauseProcBind : StrEnumAttr<
-    "ClauseProcBind",
-    "procbind clause",
-    [ClauseProcMaster, ClauseProcClose, ClauseProcSpread]> {
-  let cppNamespace = "::mlir::omp";
-}
-
 def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
   let summary = "parallel construct";
   let description = [{
@@ -87,7 +76,7 @@
              Variadic<AnyType>:$firstprivate_vars,
              Variadic<AnyType>:$shared_vars,
              Variadic<AnyType>:$copyin_vars,
-             OptionalAttr<ClauseProcBind>:$proc_bind_val);
+             OptionalAttr<ProcBindKind>:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -408,32 +408,31 @@
       blockMapping[&bb] = llvmBB;
     }
 
-      // Then, convert blocks one by one in topological order to ensure
-      // defs are converted before uses.
-      llvm::SetVector<Block *> blocks = topologicalSort(region);
-      for (auto indexedBB : llvm::enumerate(blocks)) {
-        Block *bb = indexedBB.value();
-        llvm::BasicBlock *curLLVMBB = blockMapping[bb];
-        if (bb->isEntryBlock())
-          codeGenIPBBTI->setSuccessor(0, curLLVMBB);
-
-        // TODO: Error not returned up the hierarchy
-        if (failed(
-                convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
-          return;
-
-        // If this block has the terminator then add a jump to
-        // continuation bb
-        for (auto &op : *bb) {
-          if (isa<omp::TerminatorOp>(op)) {
-            builder.SetInsertPoint(curLLVMBB);
-            builder.CreateBr(&continuationIP);
-          }
+    // Then, convert blocks one by one in topological order to ensure
+    // defs are converted before uses.
+    llvm::SetVector<Block *> blocks = topologicalSort(region);
+    for (auto indexedBB : llvm::enumerate(blocks)) {
+      Block *bb = indexedBB.value();
+      llvm::BasicBlock *curLLVMBB = blockMapping[bb];
+      if (bb->isEntryBlock())
+        codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+
+      // TODO: Error not returned up the hierarchy
+      if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+        return;
+
+      // If this block has the terminator then add a jump to
+      // continuation bb
+      for (auto &op : *bb) {
+        if (isa<omp::TerminatorOp>(op)) {
+          builder.SetInsertPoint(curLLVMBB);
+          builder.CreateBr(&continuationIP);
         }
       }
-      // Finally, after all blocks have been traversed and values mapped,
-      // connect the PHI nodes to the results of preceding blocks.
-      connectPHINodes(region, valueMapping, blockMapping);
+    }
+    // Finally, after all blocks have been traversed and values mapped,
+    // connect the PHI nodes to the results of preceding blocks.
+    connectPHINodes(region, valueMapping, blockMapping);
   };
 
   // TODO: Perform appropriate actions according to the data-sharing
@@ -451,23 +450,24 @@
   // called for variables which have destructors/finalizers.
   auto finiCB = [&](InsertPointTy codeGenIP) {};
 
-  // TODO: The various operands of parallel operation are not handled.
-  // Parallel operation is created with some default options for now.
   llvm::Value *ifCond = nullptr;
   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
     ifCond = valueMapping.lookup(ifExprVar);
   llvm::Value *numThreads = nullptr;
   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
     numThreads = valueMapping.lookup(numThreadsVar);
+  llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
+  if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
+    pbKind = llvm::omp::getProcBindKind(bind.getValue());
   // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
   // TODO: Determine the actual alloca insertion point, e.g., the function
   // entry or the alloca insertion point as provided by the body callback
   // above.
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
-  builder.restoreIP(ompBuilder->CreateParallel(
-      builder, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads,
-      llvm::omp::OMP_PROC_BIND_default, isCancellable));
+  builder.restoreIP(
+      ompBuilder->CreateParallel(builder, allocaIP, bodyGenCB, privCB, finiCB,
+                                 ifCond, numThreads, pbKind, isCancellable));
   return success();
 }
 
diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -175,3 +175,34 @@
 
 // CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
 // CHECK: call void @__kmpc_barrier
+
+// CHECK-LABEL: define void @test_omp_parallel_3()
+llvm.func @test_omp_parallel_3() -> () {
+  // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_1]], i32 2)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_1:.*]] to {{.*}}
+  omp.parallel proc_bind(master) {
+    omp.barrier
+    omp.terminator
+  }
+  // CHECK: [[OMP_THREAD_3_2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_2]], i32 3)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_2:.*]] to {{.*}}
+  omp.parallel proc_bind(close) {
+    omp.barrier
+    omp.terminator
+  }
+  // CHECK: [[OMP_THREAD_3_3:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
+  // CHECK: call void @__kmpc_push_proc_bind(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD_3_3]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_3_3:.*]] to {{.*}}
+  omp.parallel proc_bind(spread) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
+// CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -14,6 +14,7 @@
   OpDocGen.cpp
   OpFormatGen.cpp
   OpInterfacesGen.cpp
+  OpenMPCommonGen.cpp
   PassGen.cpp
   PassDocGen.cpp
   RewriterGen.cpp
diff --git a/mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp b/mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpenMPCommonGen.cpp
@@ -0,0 +1,78 @@
+//===========- OpenMPCommonGen.cpp - OpenMP common info generator -===========//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// OpenMPCommonGen generates utility information from the single OpenMP source
+// of truth in llvm/include/llvm/Frontend/OpenMP/OMP.td.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/DirectiveEmitter.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using llvm::Clause;
+using llvm::ClauseVal;
+using llvm::raw_ostream;
+using llvm::RecordKeeper;
+using llvm::Twine;
+
+// Emits, for every clause that lists allowed values, one StrEnumAttrCase per
+// user-visible value and a StrEnumAttr grouping them, e.g.
+//   def ProcBindKindmaster : StrEnumAttrCase<"master">;
+//   def ProcBindKind: StrEnumAttr<"ClauseProcBindKind", ...> { ... }
+// Returns true on error (TableGen backend convention).
+static bool emitDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  const auto &clauses = recordKeeper.getAllDerivedDefinitions("Clause");
+
+  for (const auto &r : clauses) {
+    Clause c{r};
+    const auto &clauseVals = c.getClauseVals();
+    if (clauseVals.empty())
+      continue;
+
+    const auto enumName = c.getEnumName();
+    if (enumName.empty()) {
+      llvm::PrintError("enumClauseValue field not set in Clause " +
+                       c.getFormattedName() + ".");
+      return true;
+    }
+
+    std::vector<std::string> cvDefs;
+    for (const auto &cv : clauseVals) {
+      ClauseVal cval{cv};
+      if (!cval.isUserVisible())
+        continue;
+
+      const auto name = cval.getFormattedName();
+      std::string cvDef{(enumName + llvm::Twine(name)).str()};
+      os << "def " << cvDef << " : StrEnumAttrCase<\"" << name << "\">;\n";
+      cvDefs.push_back(cvDef);
+    }
+
+    os << "def " << enumName << ": StrEnumAttr<\n";
+    os << "  \"Clause" << enumName << "\",\n";
+    os << "  \"" << enumName << " Clause\",\n";
+    os << "  [" << llvm::join(cvDefs, ",") << "]> {\n";
+    os << "    let cppNamespace = \"::mlir::omp\";\n";
+    os << "}\n";
+  }
+  return false;
+}
+
+// Registers the generator to mlir-tblgen.
+static mlir::GenRegistration
+    genDirectiveDecls("gen-directive-decl",
+                      "Generate declarations for directives (OpenMP etc.)",
+                      [](const RecordKeeper &records, raw_ostream &os) {
+                        return emitDecls(records, os);
+                      });