diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -13,6 +13,12 @@
 #ifndef MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
 #define MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
 
+// Forward declaration: OpenMPDialect::getRTLFlags(), declared through
+// extraClassDeclaration in OpenMPOps.td, returns this attribute type.
+namespace mlir::omp {
+class RTLModuleFlagsAttr;
+} // namespace mlir::omp
+
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -28,6 +28,20 @@
   let cppNamespace = "::mlir::omp";
   let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Attach an omp.rtlmoduleflags attribute (RTLModuleFlagsAttr) to
+    /// `module`, holding the given runtime-library flag values.
+    static void setRTLFlags(mlir::ModuleOp module, uint32_t debugKind,
+                            bool assumeTeamsOversubscription,
+                            bool assumeThreadsOversubscription,
+                            bool assumeNoThreadState,
+                            bool assumeNoNestedParallelism);
+
+    /// Return the omp.rtlmoduleflags attribute attached to `module`, or a
+    /// null attribute if it is absent or is not an RTLModuleFlagsAttr.
+    static mlir::omp::RTLModuleFlagsAttr getRTLFlags(mlir::ModuleOp module);
+  }];
 }
 
 // OmpCommon requires definition of OpenACC_Dialect.
@@ -42,6 +56,34 @@
 def OpenMP_PointerLikeType : TypeAlias<OpenMP_PointerLikeTypeInterface,
 	"OpenMP-compatible variable type">;
 
+// Base class for all OpenMP dialect attribute definitions.
+class OpenMP_Attr<string name, string attrMnemonic,
+                  list<Trait> traits = [],
+                  string baseCppClass = "::mlir::Attribute">
+    : AttrDef<OpenMP_Dialect, name, traits, baseCppClass> {
+  let mnemonic = attrMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute holding runtime library (RTL) flags for lowering to LLVM
+//===----------------------------------------------------------------------===//
+
+def RTLModuleFlagsAttr : OpenMP_Attr<"RTLModuleFlags", "rtlmoduleflags"> {
+  let parameters = (ins
+    "uint32_t":$debug_kind,
+    "bool":$assume_teams_oversubscription,
+    "bool":$assume_threads_oversubscription,
+    "bool":$assume_no_thread_state,
+    "bool":$assume_no_nested_parallelism
+  );
+
+  let assemblyFormat = "`<` `debug_kind` `:` $debug_kind `,`"
+    " `assume_teams_oversubscription` `:` $assume_teams_oversubscription `,`"
+    " `assume_threads_oversubscription` `:` $assume_threads_oversubscription `,`"
+    " `assume_no_thread_state` `:` $assume_no_thread_state `,`"
+    " `assume_no_nested_parallelism` `:` $assume_no_nested_parallelism `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // 2.6 parallel Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1417,6 +1417,40 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// OpenMPDialect helper functions
+//===----------------------------------------------------------------------===//
+
+/// Name of the module attribute that stores the RTLModuleFlagsAttr
+/// ("omp.rtlmoduleflags"). It is derived from the attribute mnemonic and
+/// shared by setRTLFlags() and getRTLFlags() so the two can never disagree
+/// on the key.
+static std::string getRTLFlagsAttrName() {
+  return ("omp." + RTLModuleFlagsAttr::getMnemonic()).str();
+}
+
+/// Attach an omp.rtlmoduleflags attribute to `module` holding the given
+/// runtime-library flag values; an existing value is overwritten.
+void OpenMPDialect::setRTLFlags(mlir::ModuleOp module, uint32_t debugKind,
+                                bool assumeTeamsOversubscription,
+                                bool assumeThreadsOversubscription,
+                                bool assumeNoThreadState,
+                                bool assumeNoNestedParallelism) {
+  module->setAttr(getRTLFlagsAttrName(),
+                  RTLModuleFlagsAttr::get(module->getContext(), debugKind,
+                                          assumeTeamsOversubscription,
+                                          assumeThreadsOversubscription,
+                                          assumeNoThreadState,
+                                          assumeNoNestedParallelism));
+}
+
+/// Return the omp.rtlmoduleflags attribute attached to `module`, or a null
+/// attribute if it is absent or is not an RTLModuleFlagsAttr.
+RTLModuleFlagsAttr OpenMPDialect::getRTLFlags(mlir::ModuleOp module) {
+  return llvm::dyn_cast_or_null<RTLModuleFlagsAttr>(
+      module->getAttr(getRTLFlagsAttrName()));
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/attr.mlir b/mlir/test/Dialect/OpenMP/attr.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/attr.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s
+
+// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 0, assume_teams_oversubscription : false, assume_threads_oversubscription : false, assume_no_thread_state : false, assume_no_nested_parallelism : false>} {
+module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 0, assume_teams_oversubscription : false, assume_threads_oversubscription : false, assume_no_thread_state : false, assume_no_nested_parallelism : false>} {}
+
+// -----
+
+// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 1, assume_teams_oversubscription : true, assume_threads_oversubscription : false, assume_no_thread_state : false, assume_no_nested_parallelism : false>} {
+module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 1, assume_teams_oversubscription : true, assume_threads_oversubscription : false, assume_no_thread_state : false, assume_no_nested_parallelism : false>} {}
+
+// -----
+
+// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 2, assume_teams_oversubscription : false, assume_threads_oversubscription : true, assume_no_thread_state : false, assume_no_nested_parallelism : false>} {
+module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 2, assume_teams_oversubscription : false, assume_threads_oversubscription : true, assume_no_thread_state : false, assume_no_nested_parallelism : false>} {}
+
+// -----
+
+// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 3, assume_teams_oversubscription : false, assume_threads_oversubscription : false, assume_no_thread_state : true, assume_no_nested_parallelism : false>} {
+module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 3, assume_teams_oversubscription : false, assume_threads_oversubscription : false, assume_no_thread_state : true, assume_no_nested_parallelism : false>} {}
+
+// -----
+
+// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 4, assume_teams_oversubscription : false, assume_threads_oversubscription : false, assume_no_thread_state : false, assume_no_nested_parallelism : true>} {
+module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<debug_kind : 4, assume_teams_oversubscription : false, assume_threads_oversubscription : false, assume_no_thread_state : false, assume_no_nested_parallelism : true>} {}