diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -7,6 +7,8 @@
 mlir_tablegen(OpenMPOpsDialect.cpp.inc -gen-dialect-defs -dialect=omp)
 mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
 mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
+mlir_tablegen(OpenMPOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=omp) # TypeDef class declarations of the omp dialect (e.g. DataBoundsType)
+mlir_tablegen(OpenMPOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=omp) # matching TypeDef class definitions
 mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(OpenMPOpsEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(OpenMPOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=omp)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -22,6 +22,9 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#define GET_TYPEDEF_CLASSES // select the TypeDef class declarations from the generated .inc below
+#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"
+
 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -29,6 +29,7 @@
   let cppNamespace = "::mlir::omp";
   let dependentDialects = ["::mlir::LLVM::LLVMDialect, ::mlir::func::FuncDialect"];
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1; // use the ODS-generated type parser/printer; omp types (OpenMP_Type) declare a mnemonic
 }
 
 // OmpCommon requires definition of OpenACC_Dialect.
@@ -89,6 +90,10 @@ def OpenMP_PointerLikeType : TypeAlias; +class OpenMP_Type : TypeDef { + let mnemonic = typeMnemonic; +} + //===----------------------------------------------------------------------===// // 2.12.7 Declare Target Directive //===----------------------------------------------------------------------===// @@ -1004,6 +1009,156 @@ }]; } +//===----------------------------------------------------------------------===// +// Map related constructs +//===----------------------------------------------------------------------===// + +def CaptureThis : I32EnumAttrCase<"This", 0>; +def CaptureByRef : I32EnumAttrCase<"ByRef", 1>; +def CaptureByCopy : I32EnumAttrCase<"ByCopy", 2>; +def CaptureVLAType : I32EnumAttrCase<"VLAType", 3>; + +def VariableCaptureKind : I32EnumAttr< + "VariableCaptureKind", + "variable capture kind", + [CaptureThis, CaptureByRef, CaptureByCopy, CaptureVLAType]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::omp"; +} + +def VariableCaptureKindAttr : EnumAttr { + let assemblyFormat = "`(` $value `)`"; +} + +def DataBoundsType : OpenMP_Type<"DataBounds", "data_bounds_ty"> { + let summary = "Type for representing omp data clause bounds information"; +} + +def DataBoundsOp : OpenMP_Op<"bounds", + [AttrSizedOperandSegments, NoMemoryEffect]> { + let summary = "Represents normalized bounds information for map and data clauses."; + + let description = [{ + This operation is a variation on the OpenACC dialect's DataBoundsOp; + it is used to record bounds used in map and data clauses in a + normalized fashion (zero-based). This works well with the `PointerLikeType` + requirement in data clauses - since a `lower_bound` of 0 means looking + at data at the zero offset from the pointer. + + The operation must have an `upper_bound` or `extent` (or both are allowed - + but not checked for consistency). When the source language's arrays are + not zero-based, the `start_idx` must specify the zero-position index.
+ + Examples below show copying a slice of a 10-element array except the first element. + To simplify examples, the constants are used directly in the omp.bounds + operands - this is not the actual syntax of the operation. + + C++: + ``` + int array[10]; + #pragma omp target map(array[1:9]) + ``` + => + ```mlir + omp.bounds lower_bound(1) upper_bound(9) extent(9) start_idx(0) + ``` + + Fortran: + ``` + integer :: array(1:10) + !$omp target map(array(2:10)) + ``` + => + ```mlir + omp.bounds lower_bound(1) upper_bound(9) extent(9) start_idx(1) + ``` + }]; + + let arguments = (ins Optional:$lower_bound, + Optional:$upper_bound, + Optional:$extent, + Optional:$stride, + DefaultValuedAttr:$stride_in_bytes, + Optional:$start_idx); + let results = (outs DataBoundsType:$result); + + let assemblyFormat = [{ + oilist( + `lower_bound` `(` $lower_bound `:` type($lower_bound) `)` + | `upper_bound` `(` $upper_bound `:` type($upper_bound) `)` + | `extent` `(` $extent `:` type($extent) `)` + | `stride` `(` $stride `:` type($stride) `)` + | `start_idx` `(` $start_idx `:` type($start_idx) `)` + ) attr-dict + }]; + + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + return getNumOperands(); + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + return getOperands()[i]; + } + }]; + + let hasVerifier = 1; +} + +def MapEntryOp : OpenMP_Op<"map_entry", [AttrSizedOperandSegments]> { + let arguments = (ins OpenMP_PointerLikeType:$var_ptr, + Optional:$var_ptr_ptr, + Variadic:$bounds, /* rank-0 to rank-{n-1} */ + OptionalAttr:$map_type, + OptionalAttr:$map_capture_type, + DefaultValuedAttr:$implicit, + OptionalAttr:$name); + let results = (outs OpenMP_PointerLikeType:$omp_ptr); + + let description = [{ + Description of arguments: + - `var_ptr`: The address of variable to copy. + - `var_ptr_ptr`: Specifies the address of `var_ptr` - only used when the variable + copied is a field in a struct.
+ - `bounds`: Used when copying just a slice of an array or when its bounds are not + encoded in the type. They are in rank order where rank 0 is the inner-most dimension. + - `implicit`: whether the variable was captured implicitly (used in a target + region with no map or other data mapping construct) rather than mapped explicitly. + - `map_type`: OpenMP map type for this map capture (the `map_clauses` keyword), + for example: from, to and always, usually mixed in with other map type + signifiers such as implicit or target_param. It's a bitfield composed of + the OpenMP runtime flags stored in OpenMPOffloadMappingFlags. + - `map_capture_type`: Capture type for the variable (This, ByRef, ByCopy or + VLAType, printed via the `capture` keyword); this can affect lowering. + - `name`: Holds the name of variable as specified in user clause (including bounds). + }]; + + let assemblyFormat = [{ + `var_ptr` `(` $var_ptr `:` type($var_ptr) `)` + oilist( + `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)` + | `map_clauses` `(` custom($map_type) `)` + | `capture` `(` custom($map_capture_type) `)` + | `bounds` `(` $bounds `)` + ) `->` type($omp_ptr) attr-dict + }]; + + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + return getNumOperands(); + } + + /// The i-th variable operand passed.
+ Value getVariableOperand(unsigned i) { + return getOperands()[i]; + } + }]; +} + //===---------------------------------------------------------------------===// // 2.14.2 target data Construct //===---------------------------------------------------------------------===// @@ -1044,16 +1199,14 @@ Optional:$device, Variadic:$use_device_ptr, Variadic:$use_device_addr, - Variadic:$map_operands, - OptionalAttr:$map_types); + Variadic:$map_operands); let regions = (region AnyRegion:$region); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` - | `map` - `(` custom($map_operands, type($map_operands), $map_types) `)` + | `map_entries` `(` $map_operands `:` type($map_operands) `)` | `use_device_ptr` `(` $use_device_ptr `:` type($use_device_ptr) `)` | `use_device_addr` `(` $use_device_addr `:` type($use_device_addr) `)`) $region attr-dict @@ -1095,15 +1248,14 @@ let arguments = (ins Optional:$if_expr, Optional:$device, UnitAttr:$nowait, - Variadic:$map_operands, - I64ArrayAttr:$map_types); + Variadic:$map_operands); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` - | `nowait` $nowait) - `map` `(` custom($map_operands, type($map_operands), $map_types) `)` - attr-dict + | `nowait` $nowait + | `map_entries` `(` $map_operands `:` type($map_operands) `)` + ) attr-dict }]; let hasVerifier = 1; @@ -1142,15 +1294,14 @@ let arguments = (ins Optional:$if_expr, Optional:$device, UnitAttr:$nowait, - Variadic:$map_operands, - I64ArrayAttr:$map_types); + Variadic:$map_operands); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` - | `nowait` $nowait) - `map` `(` custom($map_operands, type($map_operands), $map_types) `)` - attr-dict + | `nowait` $nowait + | `map_entries` `(` $map_operands `:` type($map_operands) `)` + ) attr-dict }]; let hasVerifier = 1; @@ -1186,8 +1337,7 @@
Optional:$device, Optional:$thread_limit, UnitAttr:$nowait, - Variadic:$map_operands, - OptionalAttr:$map_types); + Variadic:$map_operands); let regions = (region AnyRegion:$region); @@ -1196,7 +1346,7 @@ | `device` `(` $device `:` type($device) `)` | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)` | `nowait` $nowait - | `map` `(` custom($map_operands, type($map_operands), $map_types) `)` + | `map_entries` `(` $map_operands `:` type($map_operands) `)` ) $region attr-dict }]; diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -207,10 +207,10 @@ typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); - target.addDynamicallyLegalOp( + target.addDynamicallyLegalOp< + mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp, + mlir::omp::ThreadprivateOp, mlir::omp::YieldOp, mlir::omp::EnterDataOp, + mlir::omp::ExitDataOp, mlir::omp::DataBoundsOp, mlir::omp::MapEntryOp>( [&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); @@ -230,6 +230,12 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // This type is allowed when converting OpenMP to LLVM Dialect, it carries + // bounds information for map clauses and the operation and type are + // discarded on lowering to LLVM-IR from the OpenMP dialect.
+ converter.addConversion( + [&](omp::DataBoundsType type) -> Type { return type; }); + patterns.add< AtomicReadOpConversion, ReductionOpConversion, ReductionDeclareOpConversion, RegionOpConversion, @@ -245,7 +251,9 @@ RegionLessOpWithVarOperandsConversion, RegionLessOpConversion, RegionLessOpConversion, - RegionLessOpConversion>(converter); + RegionLessOpConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion>(converter); } namespace { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -66,6 +66,10 @@ #define GET_ATTRDEF_LIST #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" + >(); addInterface(); LLVM::LLVMPointerType::attachInterface< @@ -660,187 +664,266 @@ //===----------------------------------------------------------------------===// // Parser, printer and verifier for Target //===----------------------------------------------------------------------===// -/// Parses a Map Clause. -/// -/// map-clause = `map (` ( `(` `always, `? `close, `? `present, `? 
( `to` | -/// `from` | `delete` ) ` -> ` symbol-ref ` : ` type(symbol-ref) `)` )+ `)` -/// Eg: map((release -> %1 : !llvm.ptr>), (always, close, from -/// -> %2 : !llvm.ptr>)) -static ParseResult -parseMapClause(OpAsmParser &parser, - SmallVectorImpl &map_operands, - SmallVectorImpl &map_operand_types, ArrayAttr &map_types) { - StringRef mapTypeMod; - OpAsmParser::UnresolvedOperand arg1; - Type arg1Type; - IntegerAttr arg2; - SmallVector mapTypesVec; - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits; +// Returns the bitwise AND of `value` and `flag` (non-zero iff they share a bit). +static uint64_t mapTypeToBitFlag(uint64_t value, + llvm::omp::OpenMPOffloadMappingFlags flag) { + return value & + static_cast< + std::underlying_type_t>( + flag); +} + +/// Parses the `map_clauses` keyword list of an omp.map_entry op into its +/// numeric map type (a bitfield of OpenMPOffloadMappingFlags). +/// +/// map-clause = `map_clauses (` ( `always, `? `close, `? `present, `? ( +/// `to` | `from` | `tofrom` | `delete` | `exit_release_or_enter_alloc` ) +/// `, `? `ptr_and_obj, `? `target_param, `? `return_param, `? +/// `private, `? `literal, `? `implicit, `? `ompx_hold, `? `non_contig, `? +/// `member_of`?
) `)`, where the keywords may appear in any order. +static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + + // Parses one map type keyword and ORs its flag into `mapTypeBits`; the + // accumulated bits become the `map_type` attribute once all are read. auto parseTypeAndMod = [&]() -> ParseResult { + StringRef mapTypeMod; + llvm::SMLoc keywordLoc = parser.getCurrentLocation(); if (parser.parseKeyword(&mapTypeMod)) return failure(); + // Reject unknown keywords. `exit_release_or_enter_alloc` sets no bit: it + // denotes the absence of to/from/delete. + static const StringRef knownMapTypeKeywords[] = { + "always", "close", "present", "to", "from", "tofrom", "delete", + "exit_release_or_enter_alloc", "ptr_and_obj", "target_param", + "return_param", "private", "literal", "implicit", "ompx_hold", + "non_contig", "member_of"}; + if (!llvm::is_contained(knownMapTypeKeywords, mapTypeMod)) + return parser.emitError(keywordLoc, "unknown map type keyword '") + << mapTypeMod << "'"; if (mapTypeMod == "always") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + if (mapTypeMod == "close") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + if (mapTypeMod == "present") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; if (mapTypeMod == "to") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + if (mapTypeMod == "from") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + if (mapTypeMod == "tofrom") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + if (mapTypeMod == "delete") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; - return success(); - }; - auto parseMap = [&]() -> ParseResult { - mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + if (mapTypeMod == "ptr_and_obj") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; + + if (mapTypeMod == 
"target_param") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; + + if (mapTypeMod == "return_param") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + + if (mapTypeMod == "private") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE; + + if (mapTypeMod == "literal") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL; + + if (mapTypeMod == "implicit") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + + if (mapTypeMod == "ompx_hold") +
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; + + if (mapTypeMod == "non_contig") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG; + + if (mapTypeMod == "member_of") + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF; - if (parser.parseLParen() || - parser.parseCommaSeparatedList(parseTypeAndMod) || - parser.parseArrow() || parser.parseOperand(arg1) || - parser.parseColon() || parser.parseType(arg1Type) || - parser.parseRParen()) - return failure(); - map_operands.push_back(arg1); - map_operand_types.push_back(arg1Type); - arg2 = parser.getBuilder().getIntegerAttr( - parser.getBuilder().getI64Type(), - static_cast< - std::underlying_type_t>( - mapTypeBits)); - mapTypesVec.push_back(arg2); return success(); }; - if (parser.parseCommaSeparatedList(parseMap)) + if (parser.parseCommaSeparatedList(parseTypeAndMod)) return failure(); - SmallVector mapTypesAttr(mapTypesVec.begin(), mapTypesVec.end()); - map_types = ArrayAttr::get(parser.getContext(), mapTypesAttr); + mapType = parser.getBuilder().getIntegerAttr( + parser.getBuilder().getIntegerType(64, /*isSigned=*/false), + static_cast>( + mapTypeBits)); + return success(); } +/// Prints a map_entries map type from its numeric value out into its string +/// format.
static void printMapClause(OpAsmPrinter &p, Operation *op, - OperandRange map_operands, - TypeRange map_operand_types, ArrayAttr map_types) { - - // Helper function to get bitwise AND of `value` and 'flag' - auto bitAnd = [](int64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) -> bool { - return value & - static_cast< - std::underlying_type_t>( - flag); - }; - - assert(map_operands.size() == map_types.size()); - - for (unsigned i = 0, e = map_operands.size(); i < e; i++) { - int64_t mapTypeBits = 0x00; - Value mapOp = map_operands[i]; - Attribute mapTypeOp = map_types[i]; - - assert(llvm::isa(mapTypeOp)); - mapTypeBits = llvm::cast(mapTypeOp).getInt(); - - bool always = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); - bool close = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - bool present = bitAnd( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT); - - bool to = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - bool del = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); - - std::string typeModStr, typeStr; - llvm::raw_string_ostream typeMod(typeModStr), type(typeStr); - - if (always) - typeMod << "always, "; - if (close) - typeMod << "close, "; - if (present) - typeMod << "present, "; - - if (to) - type << "to"; - if (from) - type << "from"; - if (del) - type << "delete"; - if (type.str().empty()) - type << (isa(op) ? 
"release" : "alloc"); - - p << '(' << typeMod.str() << type.str() << " -> " << mapOp << " : " - << mapOp.getType() << ')'; - if (i + 1 < e) + IntegerAttr mapType) { + uint64_t mapTypeBits = mapType.getUInt(); + + bool emitAllocRelease = true; + llvm::SmallVector mapTypeStrs; + + // handling of always, close, present placed at the beginning of the string + // to aid readability + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)) + mapTypeStrs.push_back("always"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE)) + mapTypeStrs.push_back("close"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) + mapTypeStrs.push_back("present"); + + // special handling of to/from/tofrom/delete and release/alloc, release + + // alloc are the absence of one of the other flags, whereas tofrom requires + // both the to and from flag to be set. + bool to = mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + bool from = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + if (to && from) { + emitAllocRelease = false; + mapTypeStrs.push_back("tofrom"); + } else if (from) { + emitAllocRelease = false; + mapTypeStrs.push_back("from"); + } else if (to) { + emitAllocRelease = false; + mapTypeStrs.push_back("to"); + } + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) { + emitAllocRelease = false; + mapTypeStrs.push_back("delete"); + } + if (emitAllocRelease) + mapTypeStrs.push_back("exit_release_or_enter_alloc"); + + // handling of non-specification runtime related flags + if (mapTypeToBitFlag( + mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ)) + mapTypeStrs.push_back("ptr_and_obj"); + if (mapTypeToBitFlag( + mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM)) + mapTypeStrs.push_back("target_param"); + if 
(mapTypeToBitFlag( + mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) + mapTypeStrs.push_back("return_param"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE)) + mapTypeStrs.push_back("private"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL)) + mapTypeStrs.push_back("literal"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)) + mapTypeStrs.push_back("implicit"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD)) + mapTypeStrs.push_back("ompx_hold"); + if (mapTypeToBitFlag( + mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG)) + mapTypeStrs.push_back("non_contig"); + + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF)) + mapTypeStrs.push_back("member_of"); + + for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) { + p << mapTypeStrs[i]; + if (i + 1 < mapTypeStrs.size()) { p << ", "; + } } } -static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands, - std::optional map_types) { - // Helper function to get bitwise AND of `value` and 'flag' - auto bitAnd = [](int64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) -> bool { - return value & - static_cast< - std::underlying_type_t>( - flag); - }; - if (!map_types) { - if (!map_operands.empty()) - return emitError(op->getLoc(), "missing mapTypes"); - else - return success(); - } +static void printCaptureType(OpAsmPrinter &p, Operation *op, + VariableCaptureKindAttr mapCaptureType) { + // Print the case name of the capture kind (This, ByRef, ByCopy or VLAType), + // which is exactly the keyword accepted by parseCaptureType. + p << mlir::omp::stringifyVariableCaptureKind(mapCaptureType.getValue()); +}
+ +static ParseResult parseCaptureType(OpAsmParser &parser, + VariableCaptureKindAttr &mapCapture) { + llvm::SMLoc keyLoc = parser.getCurrentLocation(); + StringRef mapCaptureKey; + if (parser.parseKeyword(&mapCaptureKey)) + return failure(); - if (map_operands.empty() && !map_types->empty()) - return emitError(op->getLoc(), "missing mapOperands"); + // The keywords are the VariableCaptureKind case names (This, ByRef, ByCopy, + // VLAType); reject anything else rather than leaving the attribute unset. + auto captureKind = mlir::omp::symbolizeVariableCaptureKind(mapCaptureKey); + if (!captureKind) + return parser.emitError(keyLoc, "unknown capture type '") + << mapCaptureKey << "'"; + mapCapture = mlir::omp::VariableCaptureKindAttr::get(parser.getContext(), + *captureKind); - if (map_types->empty() && !map_operands.empty()) - return emitError(op->getLoc(), "missing mapTypes"); + return success(); +} - if (map_operands.size() != map_types->size()) - return emitError(op->getLoc(), - "mismatch in number of mapOperands and mapTypes"); +// Verifies that each map operand of `op` is defined by an omp.map_entry op +// that carries a map type and a capture type, and that the map type is +// permitted for `op`. Each check returns on failure so that later checks never +// dereference a missing defining op or an unset map type. +static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) { - for (const auto &mapTypeOp : *map_types) { - int64_t mapTypeBits = 0x00; + for (auto mapOp : mapOperands) { + if (!mapOp.getDefiningOp()) + return emitError(op->getLoc(), "missing map operation"); - if (!llvm::isa(mapTypeOp)) - return failure(); + if (auto mapEntryOp = + mlir::dyn_cast(mapOp.getDefiningOp())) { + + if (!mapEntryOp.getMapType().has_value()) + return emitError(op->getLoc(), "missing map type for map operand"); + + if (!mapEntryOp.getMapCaptureType().has_value()) + return emitError(op->getLoc(), + "missing map capture type for map operand"); + + uint64_t mapTypeBits = mapEntryOp.getMapType().value(); - 
mapTypeBits = llvm::cast(mapTypeOp).getInt(); - - bool to = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - bool del = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); - - if ((isa(op) || isa(op)) && del) - return emitError(op->getLoc(), - "to, from, tofrom and alloc map types are permitted"); - if (isa(op) && (from || del)) - return emitError(op->getLoc(), "to and alloc map types are permitted"); - if (isa(op) && to) - return emitError(op->getLoc(), - "from, release and delete map types are permitted"); + bool to = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + bool from = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + bool del = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); + + if ((isa(op) || isa(op)) && del) + return emitError(op->getLoc(), + "to, from, tofrom and alloc map types are permitted"); + + if (isa(op) && (from || del)) + return emitError(op->getLoc(), "to and alloc map types are permitted"); + + if (isa(op) && to) + return emitError(op->getLoc(), + "from, release and delete map types are permitted"); + } else { + return emitError(op->getLoc(), + "map argument is not a map entry operation"); + } } return success(); @@ -852,19 +935,19 @@ return ::emitError(this->getLoc(), "At least one of map, useDevicePtr, or " "useDeviceAddr operand must be present"); } - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } LogicalResult EnterDataOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } LogicalResult ExitDataOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } LogicalResult
TargetOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } //===----------------------------------------------------------------------===// @@ -1455,8 +1538,23 @@ return success(); } +//===----------------------------------------------------------------------===// +// DataBoundsOp +//===----------------------------------------------------------------------===// + +LogicalResult DataBoundsOp::verify() { + auto extent = getExtent(); + auto upperbound = getUpperBound(); + if (!extent && !upperbound) + return emitError("expected extent or upperbound."); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -193,13 +193,26 @@ // CHECK-LABEL: @_QPomp_target_data // CHECK: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr, %[[ARG3:.*]]: !llvm.ptr) -// CHECK: omp.target_enter_data map((to -> %[[ARG0]] : !llvm.ptr), (to -> %[[ARG1]] : !llvm.ptr), (always, alloc -> %[[ARG2]] : !llvm.ptr)) -// CHECK: omp.target_exit_data map((from -> %[[ARG0]] : !llvm.ptr), (from -> %[[ARG1]] : !llvm.ptr), (release -> %[[ARG2]] : !llvm.ptr), (always, delete -> %[[ARG3]] : !llvm.ptr)) -// CHECK: llvm.return +// CHECK: %[[MAP0:.*]] = omp.map_entry var_ptr(%[[ARG0]] : !llvm.ptr) map_clauses(to, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP1:.*]] = omp.map_entry var_ptr(%[[ARG1]] : !llvm.ptr) map_clauses(to, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK:
%[[MAP2:.*]] = omp.map_entry var_ptr(%[[ARG2]] : !llvm.ptr) map_clauses(always, exit_release_or_enter_alloc, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: omp.target_enter_data map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) +// CHECK: %[[MAP3:.*]] = omp.map_entry var_ptr(%[[ARG0]] : !llvm.ptr) map_clauses(from, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP4:.*]] = omp.map_entry var_ptr(%[[ARG1]] : !llvm.ptr) map_clauses(from, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP5:.*]] = omp.map_entry var_ptr(%[[ARG2]] : !llvm.ptr) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP6:.*]] = omp.map_entry var_ptr(%[[ARG3]] : !llvm.ptr) map_clauses(always, delete, target_param) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]], %[[MAP6]] : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) llvm.func @_QPomp_target_data(%a : !llvm.ptr, %b : !llvm.ptr, %c : !llvm.ptr, %d : !llvm.ptr) { - omp.target_enter_data map((to -> %a : !llvm.ptr), (to -> %b : !llvm.ptr), (always, alloc -> %c : !llvm.ptr)) - omp.target_exit_data map((from -> %a : !llvm.ptr), (from -> %b : !llvm.ptr), (release -> %c : !llvm.ptr), (always, delete -> %d : !llvm.ptr)) + %0 = omp.map_entry var_ptr(%a : !llvm.ptr) map_clauses(to, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + %1 = omp.map_entry var_ptr(%b : !llvm.ptr) map_clauses(to, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + %2 = omp.map_entry var_ptr(%c : !llvm.ptr) map_clauses(always, exit_release_or_enter_alloc, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target_enter_data map_entries(%0, %1, %2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {} + %3 = omp.map_entry var_ptr(%a : !llvm.ptr) map_clauses(from, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + %4 = omp.map_entry var_ptr(%b : !llvm.ptr)
map_clauses(from, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + %5 = omp.map_entry var_ptr(%c : !llvm.ptr) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + %6 = omp.map_entry var_ptr(%d : !llvm.ptr) map_clauses(always, delete, target_param) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target_exit_data map_entries(%3, %4, %5, %6 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {} llvm.return } @@ -207,7 +220,8 @@ // CHECK-LABEL: @_QPomp_target_data_region // CHECK: (%[[ARG0:.*]]: !llvm.ptr>, %[[ARG1:.*]]: !llvm.ptr) { -// CHECK: omp.target_data map((tofrom -> %[[ARG0]] : !llvm.ptr>)) { +// CHECK: %[[MAP_0:.*]] = omp.map_entry var_ptr(%[[ARG0]] : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) -> !llvm.ptr> {name = ""} +// CHECK: omp.target_data map_entries(%[[MAP_0]] : !llvm.ptr>) { // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32 // CHECK: llvm.store %[[VAL_1]], %[[ARG1]] : !llvm.ptr // CHECK: omp.terminator @@ -215,9 +229,10 @@ // CHECK: llvm.return llvm.func @_QPomp_target_data_region(%a : !llvm.ptr>, %i : !llvm.ptr) { - omp.target_data map((tofrom -> %a : !llvm.ptr>)) { - %1 = llvm.mlir.constant(10 : i32) : i32 - llvm.store %1, %i : !llvm.ptr + %1 = omp.map_entry var_ptr(%a : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%1 : !llvm.ptr>) { + %2 = llvm.mlir.constant(10 : i32) : i32 + llvm.store %2, %i : !llvm.ptr omp.terminator } llvm.return @@ -229,7 +244,8 @@ // CHECK: %[[ARG_0:.*]]: !llvm.ptr>, // CHECK: %[[ARG_1:.*]]: !llvm.ptr) { // CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(64 : i32) : i32 -// CHECK: omp.target thread_limit(%[[VAL_0]] : i32) map((tofrom -> %[[ARG_0]] : !llvm.ptr>)) { +// CHECK: %[[MAP:.*]] = omp.map_entry var_ptr(%[[ARG_0]] : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) -> !llvm.ptr> {name = ""} +// CHECK: omp.target thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP]] :
!llvm.ptr>) { // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32 // CHECK: llvm.store %[[VAL_1]], %[[ARG_1]] : !llvm.ptr // CHECK: omp.terminator @@ -239,9 +255,10 @@ llvm.func @_QPomp_target(%a : !llvm.ptr>, %i : !llvm.ptr) { %0 = llvm.mlir.constant(64 : i32) : i32 - omp.target thread_limit(%0 : i32) map((tofrom -> %a : !llvm.ptr>)) { - %1 = llvm.mlir.constant(10 : i32) : i32 - llvm.store %1, %i : !llvm.ptr + %1 = omp.map_entry var_ptr(%a : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target thread_limit(%0 : i32) map_entries(%1 : !llvm.ptr>) { + %2 = llvm.mlir.constant(10 : i32) : i32 + llvm.store %2, %i : !llvm.ptr omp.terminator } llvm.return @@ -384,3 +401,46 @@ llvm.func @_QFPdo_work(%arg0: !llvm.ptr {fir.bindc_name = "i"}) { llvm.return } + +// ----- + +// CHECK-LABEL: llvm.func @_QPtarget_map_with_bounds( +// CHECK: %[[ARG_0:.*]]: !llvm.ptr, +// CHECK: %[[ARG_1:.*]]: !llvm.ptr>, +// CHECK: %[[ARG_2:.*]]: !llvm.ptr>) { +// CHECK: %[[C_01:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[C_02:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_03:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_04:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C_02]] : i64) upper_bound(%[[C_01]] : i64) stride(%[[C_04]] : i64) start_idx(%[[C_04]] : i64) +// CHECK: %[[MAP0:.*]] = omp.map_entry var_ptr(%[[ARG_1]] : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) bounds(%[[BOUNDS0]]) -> !llvm.ptr> {name = ""} +// CHECK: %[[C_11:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[C_12:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_13:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_14:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C_12]] : i64) upper_bound(%[[C_11]] : i64) stride(%[[C_14]] : i64) start_idx(%[[C_14]] : i64) +// CHECK: %[[MAP1:.*]] =
omp.map_entry var_ptr(%[[ARG_2]] : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) bounds(%[[BOUNDS1]]) -> !llvm.ptr> {name = ""} +// CHECK: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr>, !llvm.ptr>) { +// CHECK: omp.terminator +// CHECK: } +// CHECK: llvm.return +// CHECK:} + +llvm.func @_QPtarget_map_with_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr>, %arg2: !llvm.ptr>) { + %0 = llvm.mlir.constant(4 : index) : i64 + %1 = llvm.mlir.constant(1 : index) : i64 + %2 = llvm.mlir.constant(1 : index) : i64 + %3 = llvm.mlir.constant(1 : index) : i64 + %4 = omp.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%3 : i64) start_idx(%3 : i64) + %5 = omp.map_entry var_ptr(%arg1 : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) bounds(%4) -> !llvm.ptr> {name = ""} + %6 = llvm.mlir.constant(4 : index) : i64 + %7 = llvm.mlir.constant(1 : index) : i64 + %8 = llvm.mlir.constant(1 : index) : i64 + %9 = llvm.mlir.constant(1 : index) : i64 + %10 = omp.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%9 : i64) start_idx(%9 : i64) + %11 = omp.map_entry var_ptr(%arg2 : !llvm.ptr>) map_clauses(tofrom, target_param) capture(ByRef) bounds(%10) -> !llvm.ptr> {name = ""} + omp.target map_entries(%5, %11 : !llvm.ptr>, !llvm.ptr>) { + omp.terminator + } + llvm.return +} diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1615,16 +1615,18 @@ // ----- func.func @omp_target(%map1: memref) { + %mapv = omp.map_entry var_ptr(%map1 : memref) map_clauses(delete, target_param) capture(ByRef) -> memref {name = ""} // expected-error @below {{to, from, tofrom and alloc map types are permitted}} - omp.target map((delete -> %map1 : memref)){} + omp.target map_entries(%mapv : memref){} return } // ----- func.func @omp_target_data(%map1: memref) { + %mapv = omp.map_entry var_ptr(%map1 : memref) map_clauses(delete, target_param)
capture(ByRef) -> memref {name = ""} // expected-error @below {{to, from, tofrom and alloc map types are permitted}} - omp.target_data map((delete -> %map1 : memref)){} + omp.target_data map_entries(%mapv : memref){} return } @@ -1639,16 +1641,18 @@ // ----- func.func @omp_target_enter_data(%map1: memref) { + %mapv = omp.map_entry var_ptr(%map1 : memref) map_clauses(from, target_param) capture(ByRef) -> memref {name = ""} // expected-error @below {{to and alloc map types are permitted}} - omp.target_enter_data map((from -> %map1 : memref)){} + omp.target_enter_data map_entries(%mapv : memref){} return } // ----- func.func @omp_target_exit_data(%map1: memref) { + %mapv = omp.map_entry var_ptr(%map1 : memref) map_clauses(to, target_param) capture(ByRef) -> memref {name = ""} // expected-error @below {{from, release and delete map types are permitted}} - omp.target_exit_data map((to -> %map1 : memref)){} + omp.target_exit_data map_entries(%mapv : memref){} return } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -490,12 +490,18 @@ }) {nowait, operandSegmentSizes = array} : ( i1, si32, i32 ) -> () // Test with optional map clause.
- // CHECK: omp.target map((tofrom -> %{{.*}} : memref), (alloc -> %{{.*}} : memref)) { - omp.target map((tofrom -> %map1 : memref), (alloc -> %map2 : memref)){} - - // CHECK: omp.target map((to -> %{{.*}} : memref), (always, from -> %{{.*}} : memref)) { - omp.target map((to -> %map1 : memref), (always, from -> %map2 : memref)){} - + // CHECK: %[[MAP_A:.*]] = omp.map_entry var_ptr(%[[VAL_1:.*]] : memref) map_clauses(tofrom, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: %[[MAP_B:.*]] = omp.map_entry var_ptr(%[[VAL_2:.*]] : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target map_entries(%[[MAP_A]], %[[MAP_B]] : memref, memref) { + %mapv1 = omp.map_entry var_ptr(%map1 : memref) map_clauses(tofrom, target_param) capture(ByRef) -> memref {name = ""} + %mapv2 = omp.map_entry var_ptr(%map2 : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + omp.target map_entries(%mapv1, %mapv2 : memref, memref){} + // CHECK: %[[MAP_C:.*]] = omp.map_entry var_ptr(%[[VAL_1:.*]] : memref) map_clauses(to, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: %[[MAP_D:.*]] = omp.map_entry var_ptr(%[[VAL_2:.*]] : memref) map_clauses(always, from, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target map_entries(%[[MAP_C]], %[[MAP_D]] : memref, memref) { + %mapv3 = omp.map_entry var_ptr(%map1 : memref) map_clauses(to, target_param) capture(ByRef) -> memref {name = ""} + %mapv4 = omp.map_entry var_ptr(%map2 : memref) map_clauses(always, from, target_param) capture(ByRef) -> memref {name = ""} + omp.target map_entries(%mapv3, %mapv4 : memref, memref) {} // CHECK: omp.barrier omp.barrier @@ -504,21 +510,33 @@ // CHECK-LABEL: omp_target_data func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref, %device_addr: memref, %map1: memref, %map2: memref) -> () { - // CHECK: omp.target_data if(%[[VAL_0:.*]] : i1) 
device(%[[VAL_1:.*]] : si32) map((always, from -> %[[VAL_2:.*]] : memref)) - omp.target_data if(%if_cond : i1) device(%device : si32) map((always, from -> %map1 : memref)){} - - // CHECK: omp.target_data map((close, present, to -> %[[VAL_2:.*]] : memref)) use_device_ptr(%[[VAL_3:.*]] : memref) use_device_addr(%[[VAL_4:.*]] : memref) - omp.target_data map((close, present, to -> %map1 : memref)) use_device_ptr(%device_ptr : memref) use_device_addr(%device_addr : memref) {} - - // CHECK: omp.target_data map((tofrom -> %[[VAL_2]] : memref), (alloc -> %[[VAL_5:.*]] : memref)) - omp.target_data map((tofrom -> %map1 : memref), (alloc -> %map2 : memref)){} - - // CHECK: omp.target_enter_data if(%[[VAL_0]] : i1) device(%[[VAL_1]] : si32) nowait map((alloc -> %[[VAL_2]] : memref)) - omp.target_enter_data if(%if_cond : i1) device(%device : si32) nowait map((alloc -> %map1 : memref)) - - // CHECK: omp.target_exit_data if(%[[VAL_0]] : i1) device(%[[VAL_1]] : si32) nowait map((release -> %[[VAL_5]] : memref)) - omp.target_exit_data if(%if_cond : i1) device(%device : si32) nowait map((release -> %map2 : memref)) - + // CHECK: %[[MAP_A:.*]] = omp.map_entry var_ptr(%[[VAL_2:.*]] : memref) map_clauses(always, from, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) map_entries(%[[MAP_A]] : memref) + %mapv1 = omp.map_entry var_ptr(%map1 : memref) map_clauses(always, from, target_param) capture(ByRef) -> memref {name = ""} + omp.target_data if(%if_cond : i1) device(%device : si32) map_entries(%mapv1 : memref){} + + // CHECK: %[[MAP_A:.*]] = omp.map_entry var_ptr(%[[VAL_2:.*]] : memref) map_clauses(close, present, to, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_ptr(%[[VAL_3:.*]] : memref) use_device_addr(%[[VAL_4:.*]] : memref) + %mapv2 = omp.map_entry var_ptr(%map1 : memref) map_clauses(close, present, to, target_param) 
capture(ByRef) -> memref {name = ""} + omp.target_data map_entries(%mapv2 : memref) use_device_ptr(%device_ptr : memref) use_device_addr(%device_addr : memref) {} + + // CHECK: %[[MAP_A:.*]] = omp.map_entry var_ptr(%[[VAL_1:.*]] : memref) map_clauses(tofrom, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: %[[MAP_B:.*]] = omp.map_entry var_ptr(%[[VAL_2:.*]] : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data map_entries(%[[MAP_A]], %[[MAP_B]] : memref, memref) + %mapv3 = omp.map_entry var_ptr(%map1 : memref) map_clauses(tofrom, target_param) capture(ByRef) -> memref {name = ""} + %mapv4 = omp.map_entry var_ptr(%map2 : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + omp.target_data map_entries(%mapv3, %mapv4 : memref, memref) {} + + // CHECK: %[[MAP_A:.*]] = omp.map_entry var_ptr(%[[VAL_3:.*]] : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_enter_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait map_entries(%[[MAP_A]] : memref) + %mapv5 = omp.map_entry var_ptr(%map1 : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + omp.target_enter_data if(%if_cond : i1) device(%device : si32) nowait map_entries(%mapv5 : memref) + + // CHECK: %[[MAP_A:.*]] = omp.map_entry var_ptr(%[[VAL_3:.*]] : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_exit_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait map_entries(%[[MAP_A]] : memref) + %mapv6 = omp.map_entry var_ptr(%map2 : memref) map_clauses(exit_release_or_enter_alloc, target_param) capture(ByRef) -> memref {name = ""} + omp.target_exit_data if(%if_cond : i1) device(%device : si32) nowait map_entries(%mapv6 : memref) + return } @@ -2007,3 +2025,51 @@ 
llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32 omp.yield } + +// CHECK-LABEL: omp_targets_with_map_bounds +// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr>, %[[ARG1:.*]]: !llvm.ptr>) +func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr>, %arg1: !llvm.ptr>) -> () { + // CHECK: %[[C_00:.*]] = llvm.mlir.constant(4 : index) : i64 + // CHECK: %[[C_01:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[C_02:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[C_03:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C_01]] : i64) upper_bound(%[[C_00]] : i64) stride(%[[C_02]] : i64) start_idx(%[[C_03]] : i64) + // CHECK: %[[MAP0:.*]] = omp.map_entry var_ptr(%[[ARG0]] : !llvm.ptr>) map_clauses(tofrom, target_param, non_contig, member_of) capture(ByRef) bounds(%[[BOUNDS0]]) -> !llvm.ptr> {name = ""} + %0 = llvm.mlir.constant(4 : index) : i64 + %1 = llvm.mlir.constant(1 : index) : i64 + %2 = llvm.mlir.constant(1 : index) : i64 + %3 = llvm.mlir.constant(1 : index) : i64 + %4 = omp.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%2 : i64) start_idx(%3 : i64) + + %mapv1 = omp.map_entry var_ptr(%arg0 : !llvm.ptr>) map_clauses(tofrom, target_param, non_contig, member_of) capture(ByRef) bounds(%4) -> !llvm.ptr> {name = ""} + // CHECK: %[[C_10:.*]] = llvm.mlir.constant(9 : index) : i64 + // CHECK: %[[C_11:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64) + // CHECK: %[[MAP1:.*]] = omp.map_entry var_ptr(%[[ARG1]] : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc, target_param, implicit, ompx_hold) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr> {name = ""} + %6 = llvm.mlir.constant(9 : index) : i64 + %7 = llvm.mlir.constant(1 : index) : i64 
+ %8 = llvm.mlir.constant(2 : index) : i64 + %9 = llvm.mlir.constant(2 : index) : i64 + %10 = omp.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64) + %mapv2 = omp.map_entry var_ptr(%arg1 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc, target_param, implicit, ompx_hold) capture(ByCopy) bounds(%10) -> !llvm.ptr> {name = ""} + + // CHECK: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr>, !llvm.ptr>) + omp.target map_entries(%mapv1, %mapv2 : !llvm.ptr>, !llvm.ptr>){} + + // CHECK: omp.target_data map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr>, !llvm.ptr>) + omp.target_data map_entries(%mapv1, %mapv2 : !llvm.ptr>, !llvm.ptr>){} + + // CHECK: %[[MAP2:.*]] = omp.map_entry var_ptr(%[[ARG0]] : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc, target_param, private, literal) capture(VLAType) bounds(%[[BOUNDS0]]) -> !llvm.ptr> {name = ""} + // CHECK: omp.target_enter_data map_entries(%[[MAP2]] : !llvm.ptr>) + %mapv3 = omp.map_entry var_ptr(%arg0 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc, target_param, private, literal) capture(VLAType) bounds(%4) -> !llvm.ptr> {name = ""} + omp.target_enter_data map_entries(%mapv3 : !llvm.ptr>){} + + // CHECK: %[[MAP3:.*]] = omp.map_entry var_ptr(%[[ARG1]] : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc, ptr_and_obj, target_param, return_param) capture(This) bounds(%[[BOUNDS1]]) -> !llvm.ptr> {name = ""} + // CHECK: omp.target_exit_data map_entries(%[[MAP3]] : !llvm.ptr>) + %mapv4 = omp.map_entry var_ptr(%arg1 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc, ptr_and_obj, target_param, return_param) capture(This) bounds(%10) -> !llvm.ptr> {name = ""} + omp.target_exit_data map_entries(%mapv4 : !llvm.ptr>){} + + return +}