diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -404,6 +404,19 @@ } } +static mlir::Value +getIfClauseOperand(Fortran::lower::AbstractConverter &converter, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClause::If *ifClause) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + auto &expr = std::get(ifClause->v.t); + mlir::Value ifVal = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), + ifVal); +} + static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, std::size_t loopVarTypeSize) { // OpenMP runtime requires 32-bit or 64-bit loop variables. @@ -547,6 +560,121 @@ } } +static void +createTargetDataOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &opClauseList, + const llvm::omp::Directive &directive) { + Fortran::lower::StatementContext stmtCtx; + auto currentLocation = converter.getCurrentLocation(); + auto &firOpBuilder = converter.getFirOpBuilder(); + + Value ifClauseOperand, deviceOperand; + UnitAttr nowaitAttr; + SmallVector useDevicePtrOperand, useDeviceAddrOperand, mapOperands; + SmallVector mapTypes; + + auto addMapClause = [&firOpBuilder, &currentLocation, &converter, + &mapOperands, &mapTypes](const auto &mapClause) { + auto mapType = std::get( + std::get>(mapClause->v.t) + ->t); + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + switch (mapType) { + case Fortran::parser::OmpMapType::Type::To: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + break; + case Fortran::parser::OmpMapType::Type::From: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + break; + case 
Fortran::parser::OmpMapType::Type::Tofrom: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + break; + case Fortran::parser::OmpMapType::Type::Alloc: + case Fortran::parser::OmpMapType::Type::Release: + // alloc and release are the default behavior in the runtime library, + // i.e. if we don't pass any bits alloc/release that is what the + // runtime is going to do. Therefore, we don't need to signal anything + // for these two type modifiers. + break; + case Fortran::parser::OmpMapType::Type::Delete: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + } + if (std::get>( + std::get>(mapClause->v.t) + ->t) + .has_value()) { + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + } + // TODO: Add support for MapTypeModifiers close, mapper, present, iterator + auto mapTypeVal = firOpBuilder.getIntegerAttr( + firOpBuilder.getI64Type(), + static_cast< + std::underlying_type_t>( + mapTypeBits)); + + SmallVector mapOperand; + genObjectList(std::get(mapClause->v.t), + converter, mapOperand); + + for (auto mapOp : mapOperand) { + mapOperands.push_back(mapOp); + mapTypes.push_back(mapTypeVal); + } + }; + + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (const auto &ifClause = + std::get_if(&clause.u)) { + ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); + } else if (const auto &deviceClause = + std::get_if(&clause.u)) { + if (auto deviceModifier = std::get< + std::optional>( + deviceClause->v.t)) { + if (deviceModifier == + Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) { + TODO(currentLocation, "OMPD_target Device Modifier Ancestor"); + } + } + if (const auto *deviceExpr = Fortran::semantics::GetExpr( + std::get(deviceClause->v.t))) { + deviceOperand = + fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); + } + } else if (std::get_if( + &clause.u)) { + TODO(currentLocation, "OMPD_target Use Device Ptr"); + 
} else if (std::get_if( + &clause.u)) { + TODO(currentLocation, "OMPD_target Use Device Addr"); + } else if (std::get_if(&clause.u)) { + nowaitAttr = firOpBuilder.getUnitAttr(); + } else if (const auto &mapClause = + std::get_if(&clause.u)) { + addMapClause(mapClause); + } + } + + SmallVector mapTypesAttr(mapTypes.begin(), mapTypes.end()); + + if (directive == llvm::omp::Directive::OMPD_target_data) { + firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand, + useDeviceAddrOperand, mapOperands, + ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr)); + } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) { + firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, + mapOperands, ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr)); + } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) { + firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, + mapOperands, ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr)); + } +} + static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSimpleStandaloneConstruct @@ -554,29 +682,32 @@ const auto &directive = std::get( simpleStandaloneConstruct.t); + auto currentLocation = converter.getCurrentLocation(); + auto &firOpBuilder = converter.getFirOpBuilder(); + const auto &opClauseList = + std::get(simpleStandaloneConstruct.t); + switch (directive.v) { default: break; case llvm::omp::Directive::OMPD_barrier: - converter.getFirOpBuilder().create( - converter.getCurrentLocation()); + firOpBuilder.create(currentLocation); break; case llvm::omp::Directive::OMPD_taskwait: - converter.getFirOpBuilder().create( - converter.getCurrentLocation()); + firOpBuilder.create(currentLocation); break; case llvm::omp::Directive::OMPD_taskyield: - converter.getFirOpBuilder().create( - converter.getCurrentLocation()); + 
firOpBuilder.create(currentLocation); break; + case llvm::omp::Directive::OMPD_target_data: case llvm::omp::Directive::OMPD_target_enter_data: - TODO(converter.getCurrentLocation(), "OMPD_target_enter_data"); case llvm::omp::Directive::OMPD_target_exit_data: - TODO(converter.getCurrentLocation(), "OMPD_target_exit_data"); + createTargetDataOp(converter, opClauseList, directive.v); + break; case llvm::omp::Directive::OMPD_target_update: - TODO(converter.getCurrentLocation(), "OMPD_target_update"); + TODO(currentLocation, "OMPD_target_update"); case llvm::omp::Directive::OMPD_ordered: - TODO(converter.getCurrentLocation(), "OMPD_ordered"); + TODO(currentLocation, "OMPD_ordered"); } } @@ -669,19 +800,6 @@ return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); } -static mlir::Value -getIfClauseOperand(Fortran::lower::AbstractConverter &converter, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClause::If *ifClause) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - auto &expr = std::get(ifClause->v.t); - mlir::Value ifVal = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); - return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), - ifVal); -} - /* When parallel is used in a combined construct, then use this function to * create the parallel operation. It handles the parallel specific clauses * and leaves the rest for handling at the inner operations. 
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -97,6 +97,18 @@ return success(); } }; + +template +struct LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(Op curOp, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(curOp, TypeRange(), adaptor.getOperands(), + curOp.getOperation()->getAttrs()); + return success(); + } +}; } // namespace void mlir::configureOpenMPToLLVMConversionLegality( @@ -109,13 +121,14 @@ typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); - target - .addDynamicallyLegalOp( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target.addDynamicallyLegalOp([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()); }); @@ -132,7 +145,10 @@ RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, - RegionLessOpWithVarOperandsConversion>(converter); + RegionLessOpWithVarOperandsConversion, + LegalizeDataOpForLLVMTranslation, + LegalizeDataOpForLLVMTranslation, + LegalizeDataOpForLLVMTranslation>(converter); } namespace {