diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -13,6 +13,11 @@
 #ifndef MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
 #define MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
 
+// Forward-declared for OpenMPDialect::getRTLFlags in the generated dialect.
+namespace mlir::omp {
+class RTLModuleFlagsAttr;
+} // namespace mlir::omp
+
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -28,6 +28,17 @@
   let cppNamespace = "::mlir::omp";
   let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Apply an omp.RTLModuleFlagsAttr to a module with the specified values
+    /// for the flags.
+    static void setRTLFlags(Operation *module, uint32_t debugKind,
+      bool assumeTeamsOversubscription, bool assumeThreadsOversubscription,
+      bool assumeNoThreadState, bool assumeNoNestedParallelism);
+
+    /// Return the omp.RTLModuleFlagsAttr of `module`, or null if absent.
+    static mlir::omp::RTLModuleFlagsAttr getRTLFlags(Operation *module);
+  }];
 }
 
 // OmpCommon requires definition of OpenACC_Dialect.
@@ -42,6 +53,29 @@ def OpenMP_PointerLikeType : TypeAlias; +class OpenMP_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + let mnemonic = attrMnemonic; +} + +//===----------------------------------------------------------------------===// +// Runtime library flag's attribute that holds information for lowering to LLVM +//===----------------------------------------------------------------------===// + +def RTLModuleFlagsAttr : OpenMP_Attr<"RTLModuleFlags", "rtlmoduleflags"> { + let parameters = (ins + DefaultValuedParameter<"uint32_t", "0">:$debug_kind, + DefaultValuedParameter<"bool", "false">:$assume_teams_oversubscription, + DefaultValuedParameter<"bool", "false">:$assume_threads_oversubscription, + DefaultValuedParameter<"bool", "false">:$assume_no_thread_state, + DefaultValuedParameter<"bool", "false">:$assume_no_nested_parallelism + ); + + let assemblyFormat = "`<` struct(params) `>`"; +} + //===----------------------------------------------------------------------===// // 2.6 parallel Construct //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1417,6 +1417,32 @@ return success(); } +//===----------------------------------------------------------------------===// +// OpenMPDialect helper functions +//===----------------------------------------------------------------------===// + +// Apply an omp.RTLModuleFlagsAttr to a module with the specified values for the +// flags +void OpenMPDialect::setRTLFlags(Operation *module, uint32_t debugKind, + bool assumeTeamsOversubscription, + bool assumeThreadsOversubscription, + bool assumeNoThreadState, + bool assumeNoNestedParallelism) { + module->setAttr(("omp."
+ mlir::omp::RTLModuleFlagsAttr::getMnemonic()).str(), + mlir::omp::RTLModuleFlagsAttr::get( + module->getContext(), debugKind, + assumeTeamsOversubscription, + assumeThreadsOversubscription, assumeNoThreadState, + assumeNoNestedParallelism)); +} + +// Return the omp.RTLModuleFlagsAttr stored on a given module under the name +// setRTLFlags uses, or null if it is absent or has another attribute type. +RTLModuleFlagsAttr OpenMPDialect::getRTLFlags(Operation *module) { + return module->getAttrOfType<RTLModuleFlagsAttr>( + ("omp." + RTLModuleFlagsAttr::getMnemonic()).str()); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/attr.mlir b/mlir/test/Dialect/OpenMP/attr.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/OpenMP/attr.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<>} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<>} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags<>} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes
{omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {} + +// CHECK: module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} { +module attributes {omp.rtlmoduleflags = #omp.rtlmoduleflags} {}