diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1540,22 +1540,39 @@
   unsigned index = 0;
   for (const auto &mapOp : mapOperands) {
+    // Unlike dev_ptr and dev_addr operands, these map operands point to an
+    // omp.map_info operation which carries further information on the
+    // variable being mapped and how it should be mapped. getDefiningOp<>()
+    // yields null for block arguments and other producers, which are
+    // rejected below instead of being dereferenced.
+    auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>();
+
     // TODO: Only LLVMPointerTypes are handled.
-    if (!mapOp.getType().isa<LLVM::LLVMPointerType>())
+    if (!mapInfoOp || !mapInfoOp.getType().isa<LLVM::LLVMPointerType>())
       return fail();
 
-    llvm::Value *mapOpValue = moduleTranslation.lookupValue(mapOp);
+    // Expect one integer map type per map operand; bail out instead of
+    // indexing out of bounds or dereferencing a missing entry.
+    IntegerAttr mapTypeAttr;
+    if (mapTypes && index < mapTypes.size())
+      mapTypeAttr = mlir::dyn_cast_or_null<IntegerAttr>(mapTypes[index]);
+    if (!mapTypeAttr)
+      return fail();
+
+    // The runtime is handed the variable referenced by the map_info op, not
+    // the map_info result itself.
+    mlir::Value varPtr = mapInfoOp.getVarPtr();
+    llvm::Value *mapOpValue = moduleTranslation.lookupValue(varPtr);
     combinedInfo.BasePointers.emplace_back(mapOpValue);
     combinedInfo.Pointers.emplace_back(mapOpValue);
     combinedInfo.DevicePointers.emplace_back(
         llvm::OpenMPIRBuilder::DeviceInfoTy::None);
     combinedInfo.Names.emplace_back(
-        LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder));
+        LLVM::createMappingInformation(varPtr.getLoc(), *ompBuilder));
     combinedInfo.Types.emplace_back(
-        llvm::omp::OpenMPOffloadMappingFlags(
-            mapTypes[index].dyn_cast<IntegerAttr>().getInt()) |
+        llvm::omp::OpenMPOffloadMappingFlags(mapTypeAttr.getUInt()) |
         (IsTargetParams ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
                         : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE));
     combinedInfo.Sizes.emplace_back(
-        builder.getInt64(getSizeInBytes(DL, mapOp.getType())));
+        builder.getInt64(getSizeInBytes(DL, varPtr.getType())));
     index++;
   }
@@ -1609,6 +1626,26 @@
 
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
+  // Collects the map type of every map operand (each expected to be defined
+  // by an omp.map_info op) into an ArrayAttr. Exactly one entry is emitted
+  // per operand so that indices stay aligned with the operand list; operands
+  // without a usable map type get a non-integer placeholder (UnitAttr),
+  // which the map-info generation loop rejects.
+  auto getMapTypes = [](mlir::OperandRange mapOperands,
+                        mlir::MLIRContext *ctx) {
+    mlir::Attribute placeholder = mlir::UnitAttr::get(ctx);
+    SmallVector<mlir::Attribute> mapTypes;
+    mapTypes.reserve(mapOperands.size());
+    for (mlir::Value mapValue : mapOperands) {
+      mlir::Attribute mapType = placeholder;
+      if (auto mapInfoOp = mapValue.getDefiningOp<mlir::omp::MapInfoOp>())
+        if (mlir::Attribute attr = mapInfoOp.getMapTypeAttr())
+          mapType = attr;
+      mapTypes.push_back(mapType);
+    }
+    return mlir::ArrayAttr::get(ctx, mapTypes);
+  };
+
   LogicalResult result =
       llvm::TypeSwitch<Operation *, LogicalResult>(op)
           .Case([&](omp::DataOp dataOp) {
@@ -1622,8 +1659,8 @@
                   deviceID = intAttr.getInt();
 
             mapOperands = dataOp.getMapOperands();
-            if (dataOp.getMapTypes())
-              mapTypes = dataOp.getMapTypes().value();
+            mapTypes =
+                getMapTypes(dataOp.getMapOperands(), dataOp->getContext());
             useDevPtrOperands = dataOp.getUseDevicePtr();
             useDevAddrOperands = dataOp.getUseDeviceAddr();
             return success();
@@ -1642,7 +1679,8 @@
                   deviceID = intAttr.getInt();
             RTLFn = llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
             mapOperands = enterDataOp.getMapOperands();
-            mapTypes = enterDataOp.getMapTypes();
+            mapTypes = getMapTypes(enterDataOp.getMapOperands(),
+                                   enterDataOp->getContext());
             return success();
           })
           .Case([&](omp::ExitDataOp exitDataOp) {
@@ -1660,7 +1698,8 @@
 
             RTLFn = llvm::omp::OMPRTL___tgt_target_data_end_mapper;
             mapOperands = exitDataOp.getMapOperands();
-            mapTypes = exitDataOp.getMapTypes();
+            mapTypes = getMapTypes(exitDataOp.getMapOperands(),
+                                   exitDataOp->getContext());
             return success();
           })
           .Default([&](Operation *op) {
@@ -1887,7 +1926,29 @@
 
   DataLayout DL = DataLayout(opInst.getParentOfType<ModuleOp>());
   SmallVector<Value> mapOperands = targetOp.getMapOperands();
-  ArrayAttr mapTypes = targetOp.getMapTypes().value_or(nullptr);
+
+  // Collects the map type of every map operand (each expected to be defined
+  // by an omp.map_info op) into an ArrayAttr. Exactly one entry is emitted
+  // per operand so that indices stay aligned with mapOperands; operands
+  // without a usable map type get a non-integer placeholder (UnitAttr),
+  // which the map-info generation loop rejects.
+  auto getMapTypes = [](mlir::OperandRange mapOperands,
+                        mlir::MLIRContext *ctx) {
+    mlir::Attribute placeholder = mlir::UnitAttr::get(ctx);
+    SmallVector<mlir::Attribute> mapTypes;
+    mapTypes.reserve(mapOperands.size());
+    for (mlir::Value mapValue : mapOperands) {
+      mlir::Attribute mapType = placeholder;
+      if (auto mapInfoOp = mapValue.getDefiningOp<mlir::omp::MapInfoOp>())
+        if (mlir::Attribute attr = mapInfoOp.getMapTypeAttr())
+          mapType = attr;
+      mapTypes.push_back(mapType);
+    }
+    return mlir::ArrayAttr::get(ctx, mapTypes);
+  };
+
+  ArrayAttr mapTypes =
+      getMapTypes(targetOp.getMapOperands(), targetOp->getContext());
 
   llvm::OpenMPIRBuilder::MapInfosTy combinedInfos;
   auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
@@ -2170,6 +2231,12 @@
       .Case([&](omp::TargetOp) {
         return convertOmpTarget(*op, builder, moduleTranslation);
       })
+      .Case([&](omp::MapInfoOp) {
+        // No-op: map_info ops are consumed by the owning operations (e.g.
+        // TargetOp, EnterDataOp, ExitDataOp, DataOp) when those are
+        // translated, and are then discarded.
+        return success();
+      })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -2,10 +2,11 @@
 
 llvm.func @_QPopenmp_target_data() {
   %0 = llvm.mlir.constant(1 : i64) : i64
   %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr<i32>
-  omp.target_data map((tofrom -> %1 : !llvm.ptr<i32>)) {
-    %2 = llvm.mlir.constant(99 : i32) : i32
-    llvm.store %2, %1 : !llvm.ptr<i32>
+  %2 = omp.map_info var_ptr(%1 : !llvm.ptr<i32>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
+  omp.target_data map_entries(%2 : !llvm.ptr<i32>) {
+    %3 = llvm.mlir.constant(99 : i32) : i32
+    llvm.store %3, %1 : !llvm.ptr<i32>
     omp.terminator
   }
   llvm.return
@@ -38,13 +39,14 @@
 // -----
 
 llvm.func
@_QPopenmp_target_data_region(%1 : !llvm.ptr>) { - omp.target_data map((from -> %1 : !llvm.ptr>)) { - %2 = llvm.mlir.constant(99 : i32) : i32 - %3 = llvm.mlir.constant(1 : i64) : i64 + %2 = omp.map_info var_ptr(%1 : !llvm.ptr>) map_clauses(from) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%2 : !llvm.ptr>) { + %3 = llvm.mlir.constant(99 : i32) : i32 %4 = llvm.mlir.constant(1 : i64) : i64 - %5 = llvm.mlir.constant(0 : i64) : i64 - %6 = llvm.getelementptr %1[0, %5] : (!llvm.ptr>, i64) -> !llvm.ptr - llvm.store %2, %6 : !llvm.ptr + %5 = llvm.mlir.constant(1 : i64) : i64 + %6 = llvm.mlir.constant(0 : i64) : i64 + %7 = llvm.getelementptr %1[0, %6] : (!llvm.ptr>, i64) -> !llvm.ptr + llvm.store %3, %7 : !llvm.ptr omp.terminator } llvm.return @@ -90,12 +92,16 @@ %11 = llvm.mlir.constant(10 : i32) : i32 %12 = llvm.icmp "slt" %10, %11 : i32 %13 = llvm.load %5 : !llvm.ptr - omp.target_enter_data if(%12 : i1) device(%13 : i32) map((to -> %1 : !llvm.ptr>), (alloc -> %3 : !llvm.ptr>)) + %map1 = omp.map_info var_ptr(%1 : !llvm.ptr>) map_clauses(to) capture(ByRef) -> !llvm.ptr> {name = ""} + %map2 = omp.map_info var_ptr(%3 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_enter_data if(%12 : i1) device(%13 : i32) map_entries(%map1, %map2 : !llvm.ptr>, !llvm.ptr>) %14 = llvm.load %7 : !llvm.ptr %15 = llvm.mlir.constant(10 : i32) : i32 %16 = llvm.icmp "sgt" %14, %15 : i32 %17 = llvm.load %5 : !llvm.ptr - omp.target_exit_data if(%16 : i1) device(%17 : i32) map((from -> %1 : !llvm.ptr>), (release -> %3 : !llvm.ptr>)) + %map3 = omp.map_info var_ptr(%1 : !llvm.ptr>) map_clauses(from) capture(ByRef) -> !llvm.ptr> {name = ""} + %map4 = omp.map_info var_ptr(%3 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_exit_data if(%16 : i1) device(%17 : i32) map_entries(%map3, %map4 : !llvm.ptr>, !llvm.ptr>) llvm.return } @@ -172,7 +178,8 @@ llvm.func
@_QPopenmp_target_use_dev_ptr() { %0 = llvm.mlir.constant(1 : i64) : i64 %a = llvm.alloca %0 x !llvm.ptr> : (i64) -> !llvm.ptr> - omp.target_data map((from -> %a : !llvm.ptr>)) use_device_ptr(%a : !llvm.ptr>) { + %map1 = omp.map_info var_ptr(%a : !llvm.ptr>) map_clauses(from) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%map1 : !llvm.ptr>) use_device_ptr(%a : !llvm.ptr>) { ^bb0(%arg0: !llvm.ptr>): %1 = llvm.mlir.constant(10 : i32) : i32 %2 = llvm.load %arg0 : !llvm.ptr> @@ -215,7 +222,8 @@ llvm.func @_QPopenmp_target_use_dev_addr() { %0 = llvm.mlir.constant(1 : i64) : i64 %a = llvm.alloca %0 x !llvm.ptr> : (i64) -> !llvm.ptr> - omp.target_data map((from -> %a : !llvm.ptr>)) use_device_addr(%a : !llvm.ptr>) { + %map = omp.map_info var_ptr(%a : !llvm.ptr>) map_clauses(from) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%map : !llvm.ptr>) use_device_addr(%a : !llvm.ptr>) { ^bb0(%arg0: !llvm.ptr>): %1 = llvm.mlir.constant(10 : i32) : i32 %2 = llvm.load %arg0 : !llvm.ptr> @@ -256,7 +264,8 @@ llvm.func @_QPopenmp_target_use_dev_addr_no_ptr() { %0 = llvm.mlir.constant(1 : i64) : i64 %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr - omp.target_data map((tofrom -> %a : !llvm.ptr)) use_device_addr(%a : !llvm.ptr) { + %map = omp.map_info var_ptr(%a : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%a : !llvm.ptr) { ^bb0(%arg0: !llvm.ptr): %1 = llvm.mlir.constant(10 : i32) : i32 llvm.store %1, %arg0 : !llvm.ptr @@ -297,7 +306,8 @@ %a = llvm.alloca %0 x !llvm.ptr> : (i64) -> !llvm.ptr> %1 = llvm.mlir.constant(1 : i64) : i64 %b = llvm.alloca %0 x !llvm.ptr> : (i64) -> !llvm.ptr> - omp.target_data map((from -> %b : !llvm.ptr>)) use_device_addr(%a : !llvm.ptr>) { + %map = omp.map_info var_ptr(%b : !llvm.ptr>) map_clauses(from) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%map : !llvm.ptr>) use_device_addr(%a :
!llvm.ptr>) { ^bb0(%arg0: !llvm.ptr>): %2 = llvm.mlir.constant(10 : i32) : i32 %3 = llvm.load %arg0 : !llvm.ptr> @@ -352,7 +362,9 @@ %a = llvm.alloca %0 x !llvm.ptr> : (i64) -> !llvm.ptr> %1 = llvm.mlir.constant(1 : i64) : i64 %b = llvm.alloca %0 x !llvm.ptr> : (i64) -> !llvm.ptr> - omp.target_data map((tofrom -> %a : !llvm.ptr>), (tofrom -> %b : !llvm.ptr>)) use_device_ptr(%a : !llvm.ptr>) use_device_addr(%b : !llvm.ptr>) { + %map = omp.map_info var_ptr(%a : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr> {name = ""} + %map1 = omp.map_info var_ptr(%b : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%map, %map1 : !llvm.ptr>, !llvm.ptr>) use_device_ptr(%a : !llvm.ptr>) use_device_addr(%b : !llvm.ptr>) { ^bb0(%arg0: !llvm.ptr>, %arg1: !llvm.ptr>): %2 = llvm.mlir.constant(10 : i32) : i32 %3 = llvm.load %arg0 : !llvm.ptr> diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir --- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir @@ -12,7 +12,10 @@ %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operandSegmentSizes = array, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr llvm.store %1, %3 : !llvm.ptr llvm.store %0, %5 : !llvm.ptr - omp.target map((to -> %3 : !llvm.ptr), (to -> %5 : !llvm.ptr), (from -> %7 : !llvm.ptr)) { + %map1 = omp.map_info var_ptr(%3 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + %map2 = omp.map_info var_ptr(%5 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + %map3 = omp.map_info var_ptr(%7 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) { %8 = llvm.load %3 : !llvm.ptr %9 = llvm.load %5 : !llvm.ptr %10 = llvm.add %8, %9 : i32 diff --git
a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir --- a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir @@ -4,15 +4,17 @@ module attributes {omp.is_target_device = true} { llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {omp.outline_parent_name = "writeindex_"} { - omp.target map((from -> %arg0 : !llvm.ptr), (implicit -> %arg1: !llvm.ptr)) { - %0 = llvm.mlir.constant(20 : i32) : i32 - %1 = llvm.mlir.constant(10 : i32) : i32 - llvm.store %1, %arg0 : !llvm.ptr - llvm.store %0, %arg1 : !llvm.ptr + %0 = omp.map_info var_ptr(%arg0 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + %1 = omp.map_info var_ptr(%arg1 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target map_entries(%0, %1 : !llvm.ptr, !llvm.ptr) { + %2 = llvm.mlir.constant(20 : i32) : i32 + %3 = llvm.mlir.constant(10 : i32) : i32 + llvm.store %3, %arg0 : !llvm.ptr + llvm.store %2, %arg1 : !llvm.ptr omp.terminator } llvm.return } } -// CHECK: define {{.*}} void @__omp_offloading_{{.*}}_{{.*}}_writeindex__l7(ptr {{.*}}, ptr {{.*}}) { +// CHECK: define {{.*}} void @__omp_offloading_{{.*}}_{{.*}}_writeindex__l{{[0-9]+}}(ptr {{.*}}, ptr {{.*}}) { diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir --- a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir @@ -12,7 +12,10 @@ %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operandSegmentSizes = array, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr llvm.store %1, %3 : !llvm.ptr llvm.store %0, %5 : !llvm.ptr - omp.target map((to -> %3 : !llvm.ptr), (to -> %5 : !llvm.ptr), (from -> %7 : !llvm.ptr)) { + %map1 = omp.map_info var_ptr(%3 : !llvm.ptr) map_clauses(tofrom)
capture(ByRef) -> !llvm.ptr {name = ""} + %map2 = omp.map_info var_ptr(%5 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + %map3 = omp.map_info var_ptr(%7 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) { %8 = llvm.load %3 : !llvm.ptr %9 = llvm.load %5 : !llvm.ptr %10 = llvm.add %8, %9 : i32 diff --git a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir --- a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir @@ -12,7 +12,10 @@ %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operandSegmentSizes = array, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr llvm.store %1, %3 : !llvm.ptr llvm.store %0, %5 : !llvm.ptr - omp.target map((to -> %3 : !llvm.ptr), (to -> %5 : !llvm.ptr), (from -> %7 : !llvm.ptr)) { + %map1 = omp.map_info var_ptr(%3 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + %map2 = omp.map_info var_ptr(%5 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + %map3 = omp.map_info var_ptr(%7 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) { omp.parallel { %8 = llvm.load %3 : !llvm.ptr %9 = llvm.load %5 : !llvm.ptr