diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -25,28 +25,13 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" -using namespace mlir; +using DeclareTargetCapturePair = + std::pair; -void Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder, - Operation *op, mlir::Location loc) { - if (mlir::isa(op)) - builder.create(loc); - else - builder.create(loc); -} - -int64_t Fortran::lower::getCollapseValue( - const Fortran::parser::OmpClauseList &clauseList) { - for (const auto &clause : clauseList.v) { - if (const auto &collapseClause = - std::get_if(&clause.u)) { - const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); - return Fortran::evaluate::ToInt64(*expr).value(); - } - } - return 1; -} +//===----------------------------------------------------------------------===// +// Common helper functions +//===----------------------------------------------------------------------===// static Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { @@ -64,6 +49,54 @@ return sym; } +static void genObjectList(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands) { + auto addOperands = [&](Fortran::lower::SymbolRef sym) { + const mlir::Value variable = converter.getSymbolAddress(sym); + if (variable) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), sym); + } + } + }; + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + addOperands(*sym); + } +} + +static void gatherFuncAndVarSyms( + const Fortran::parser::OmpObjectList &objList, + 
mlir::omp::DeclareTargetCaptureClause clause, + llvm::SmallVectorImpl> + &symbolAndClause) { + for (const Fortran::parser::OmpObject &ompObject : objList.v) { + Fortran::common::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + if (const Fortran::parser::Name *name = + Fortran::semantics::getDesignatorNameIfDataRef( + designator)) { + symbolAndClause.emplace_back(clause, *name->symbol); + } + }, + [&](const Fortran::parser::Name &name) { + symbolAndClause.emplace_back(clause, *name.symbol); + }}, + ompObject.u); + } +} + +//===----------------------------------------------------------------------===// +// DataSharingProcessor +//===----------------------------------------------------------------------===// + class DataSharingProcessor { bool hasLastPrivateOp; mlir::OpBuilder::InsertPoint lastPrivIP; @@ -115,7 +148,7 @@ // MLIR operation to insert the last private update. Step2 adds // dealocation code as well. void processStep1(); - void processStep2(mlir::Operation *op, bool is_loop); + void processStep2(mlir::Operation *op, bool isLoop); }; void DataSharingProcessor::processStep1() { @@ -126,12 +159,12 @@ insertBarrier(); } -void DataSharingProcessor::processStep2(mlir::Operation *op, bool is_loop) { +void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) { insPt = firOpBuilder.saveInsertionPoint(); copyLastPrivatize(op); firOpBuilder.restoreInsertionPoint(insPt); - if (is_loop) { + if (isLoop) { // push deallocs out of the loop firOpBuilder.setInsertionPointAfter(op); insertDeallocs(); @@ -145,7 +178,7 @@ } void DataSharingProcessor::insertDeallocs() { - for (auto sym : privatizedSymbols) + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) if (Fortran::semantics::IsAllocatable(sym->GetUltimate())) { converter.createHostAssociateVarCloneDealloc(*sym); } @@ -203,7 +236,7 @@ } } - for (auto *ps : privatizedSymbols) { + for (const Fortran::semantics::Symbol *ps : privatizedSymbols) { if 
(ps->has()) TODO(converter.getCurrentLocation(), "Common Block in privatization clause"); @@ -214,7 +247,7 @@ } bool DataSharingProcessor ::needBarrier() { - for (auto sym : privatizedSymbols) { + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) && sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) return true; @@ -241,7 +274,7 @@ for (const Fortran::parser::OmpClause &clause : opClauseList.v) { if (std::get_if(&clause.u)) { // TODO: Add lastprivate support for simd construct - if (mlir::isa(op)) { + if (mlir::isa(op)) { if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { // For `omp.sections`, lastprivatized variables occur in // lexically final `omp.section` operation. The following FIR @@ -313,7 +346,7 @@ firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP); } } - } else if (mlir::isa(op)) { + } else if (mlir::isa(op)) { mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); firOpBuilder.setInsertionPoint(lastOper); @@ -358,7 +391,7 @@ converter.collectSymbolSet(eval, defaultSymbols, flag, /*collectSymbols=*/true, /*collectHostAssociatedSymbols=*/true); - for (auto &e : eval.getNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) { if (e.hasNestedEvaluations()) converter.collectSymbolSet(e, symbolsInNestedRegions, flag, /*collectSymbols=*/true, @@ -385,7 +418,7 @@ } void DataSharingProcessor::privatize() { - for (auto sym : privatizedSymbols) { + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { cloneSymbol(sym); copyFirstPrivateSymbol(sym); } @@ -393,12 +426,12 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) { insertLastPrivateCompare(op); - for (auto sym : privatizedSymbols) + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) copyLastPrivateSymbol(sym, &lastPrivIP); } void DataSharingProcessor::defaultPrivatize() { - for (auto sym : 
defaultSymbols) { + for (const Fortran::semantics::Symbol *sym : defaultSymbols) { if (!symbolsInNestedRegions.contains(sym) && !symbolsInParentRegions.contains(sym) && !privatizedSymbols.contains(sym)) { @@ -408,590 +441,516 @@ } } -/// The COMMON block is a global structure. \p commonValue is the base address -/// of the the COMMON block. As the offset from the symbol \p sym, generate the -/// COMMON block member value (commonValue + offset) for the symbol. -/// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp. -static mlir::Value -genCommonBlockMember(Fortran::lower::AbstractConverter &converter, - const Fortran::semantics::Symbol &sym, - mlir::Value commonValue) { - auto &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8); - mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty); - mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty)); - mlir::Value base = - firOpBuilder.createConvert(currentLocation, seqTy, commonValue); - std::size_t byteOffset = sym.GetUltimate().offset(); - mlir::Value offs = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIndexType(), byteOffset); - mlir::Value varAddr = firOpBuilder.create( - currentLocation, i8Ptr, base, mlir::ValueRange{offs}); - mlir::Type symType = converter.genType(sym); - return firOpBuilder.createConvert(currentLocation, - firOpBuilder.getRefType(symType), varAddr); -} - -// Get the extended value for \p val by extracting additional variable -// information from \p base. 
-static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, - mlir::Value val) { - return base.match( - [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue { - return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {}); - }, - [&](const auto &) -> fir::ExtendedValue { - return fir::substBase(base, val); - }); -} - -static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval) { - auto &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - auto insPt = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); +//===----------------------------------------------------------------------===// +// ClauseProcessor +//===----------------------------------------------------------------------===// - // Get the original ThreadprivateOp corresponding to the symbol and use the - // symbol value from that opeartion to create one ThreadprivateOp copy - // operation inside the parallel region. - auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value { - mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym); - mlir::Operation *op = symOriThreadprivateValue.getDefiningOp(); - assert(mlir::isa(op) && - "The threadprivate operation not created"); - mlir::Value symValue = - mlir::dyn_cast(op).getSymAddr(); - return firOpBuilder.create( - currentLocation, symValue.getType(), symValue); - }; +/// Class that handles the processing of OpenMP clauses. +/// +/// Its `process()` methods perform MLIR code generation for their +/// corresponding clause if it is present in the clause list. Otherwise, they +/// will return `false` to signal that the clause was not found. 
+/// +/// The intended use is of this class is to move clause processing outside of +/// construct processing, since the same clauses can appear attached to +/// different constructs and constructs can be combined, so that code +/// duplication is minimized. +/// +/// Each construct-lowering function only calls the `process()` +/// methods that relate to clauses that can impact the lowering of that +/// construct. +class ClauseProcessor { + using ClauseTy = Fortran::parser::OmpClause; - llvm::SetVector threadprivateSyms; - converter.collectSymbolSet( - eval, threadprivateSyms, - Fortran::semantics::Symbol::Flag::OmpThreadprivate); - std::set threadprivateSymNames; +public: + ClauseProcessor(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &clauses) + : converter(converter), clauses(clauses) {} - // For a COMMON block, the ThreadprivateOp is generated for itself instead of - // its members, so only bind the value of the new copied ThreadprivateOp - // inside the parallel region to the common block symbol only once for - // multiple members in one COMMON block. - llvm::SetVector commonSyms; - for (std::size_t i = 0; i < threadprivateSyms.size(); i++) { - auto sym = threadprivateSyms[i]; - mlir::Value symThreadprivateValue; - // The variable may be used more than once, and each reference has one - // symbol with the same name. Only do once for references of one variable. 
- if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end()) - continue; - threadprivateSymNames.insert(sym->name()); - if (const Fortran::semantics::Symbol *common = - Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) { - mlir::Value commonThreadprivateValue; - if (commonSyms.contains(common)) { - commonThreadprivateValue = converter.getSymbolAddress(*common); - } else { - commonThreadprivateValue = genThreadprivateOp(*common); - converter.bindSymbol(*common, commonThreadprivateValue); - commonSyms.insert(common); + // 'Unique' clauses: They can appear at most once in the clause list. + bool + processCollapse(mlir::Location currentLocation, + Fortran::lower::pft::Evaluation &eval, + llvm::SmallVectorImpl &lowerBound, + llvm::SmallVectorImpl &upperBound, + llvm::SmallVectorImpl &step, + llvm::SmallVectorImpl &iv, + std::size_t &loopVarTypeSize) const; + bool processDefault() const; + bool processDevice(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; + bool processFinal(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processHint(mlir::IntegerAttr &result) const; + bool processMergeable(mlir::UnitAttr &result) const; + bool processNowait(mlir::UnitAttr &result) const; + bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processOrdered(mlir::IntegerAttr &result) const; + bool processPriority(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; + bool processSafelen(mlir::IntegerAttr &result) const; + bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, + mlir::omp::ScheduleModifierAttr &modifierAttr, + mlir::UnitAttr &simdModifierAttr) const; + bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool 
processSimdlen(mlir::IntegerAttr &result) const; + bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processUntied(mlir::UnitAttr &result) const; + + // 'Repeatable' clauses: They can appear multiple times in the clause list. + bool + processAllocate(llvm::SmallVectorImpl &allocatorOperands, + llvm::SmallVectorImpl &allocateOperands) const; + bool processCopyin() const; + bool processDepend(llvm::SmallVectorImpl &dependTypeOperands, + llvm::SmallVectorImpl &dependOperands) const; + bool processIf(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool + processLink(llvm::SmallVectorImpl &result) const; + bool processMap(llvm::SmallVectorImpl &mapOperands, + llvm::SmallVectorImpl &mapTypes) const; + bool processReduction( + mlir::Location currentLocation, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols) const; + bool processSectionsReduction(mlir::Location currentLocation) const; + bool processTo(llvm::SmallVectorImpl &result) const; + bool + processUseDeviceAddr(llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) const; + bool + processUseDevicePtr(llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) const; + +private: + using ClauseIterator = std::list::const_iterator; + + /// Utility to find a clause within a range in the clause list. + template + static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end) { + for (ClauseIterator it = begin; it != end; ++it) + if (std::get_if(&it->u)) + return it; + + return end; + } + + /// Return the first instance of the given clause found in the clause list or + /// `nullptr` if not present. If more than one instance is expected, use + /// `findRepeatableClause` instead. 
+ template + const T * + findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const { + ClauseIterator it = findClause(clauses.v.begin(), clauses.v.end()); + if (it != clauses.v.end()) { + if (source) + *source = &it->source; + return &std::get(it->u); + } + return nullptr; + } + + /// Call `callbackFn` for each occurrence of the given clause. Return `true` + /// if at least one instance was found. + template + bool findRepeatableClause( + std::function + callbackFn) const { + bool found = false; + ClauseIterator nextIt, endIt = clauses.v.end(); + for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) { + nextIt = findClause(it, endIt); + + if (nextIt != endIt) { + callbackFn(&std::get(nextIt->u), nextIt->source); + found = true; + ++nextIt; } - symThreadprivateValue = - genCommonBlockMember(converter, *sym, commonThreadprivateValue); - } else { - symThreadprivateValue = genThreadprivateOp(*sym); } + return found; + } - fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym); - fir::ExtendedValue symThreadprivateExv = - getExtendedValue(sexv, symThreadprivateValue); - converter.bindSymbol(*sym, symThreadprivateExv); + /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. + template + bool markClauseOccurrence(mlir::UnitAttr &result) const { + if (findUniqueClause()) { + result = converter.getFirOpBuilder().getUnitAttr(); + return true; + } + return false; } - firOpBuilder.restoreInsertionPoint(insPt); + Fortran::lower::AbstractConverter &converter; + const Fortran::parser::OmpClauseList &clauses; +}; + +//===----------------------------------------------------------------------===// +// ClauseProcessor helper functions +//===----------------------------------------------------------------------===// + +/// Check for unsupported map operand types. 
+static void checkMapType(mlir::Location location, mlir::Type type) { + if (auto refType = type.dyn_cast()) + type = refType.getElementType(); + if (auto boxType = type.dyn_cast_or_null()) + if (!boxType.getElementType().isa()) + TODO(location, "OMPD_target_data MapOperand BoxType"); } -static void -genCopyinClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClauseList &opClauseList) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); - bool hasCopyin = false; - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (const auto ©inClause = - std::get_if(&clause.u)) { - hasCopyin = true; - const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - if (sym->has()) - TODO(converter.getCurrentLocation(), "common block in Copyin clause"); - if (Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate())) - TODO(converter.getCurrentLocation(), - "pointer or allocatable variables in Copyin clause"); - assert(sym->has() && - "No host-association found"); - converter.copyHostAssociateVar(*sym); - } - } - } - // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to - // the execution of the associated structured block. Emit implicit barrier to - // synchronize threads and avoid data races on propagation master's thread - // values of threadprivate variables to local instances of that variables of - // all other implicit threads. - if (hasCopyin) - firOpBuilder.create(converter.getCurrentLocation()); - firOpBuilder.restoreInsertionPoint(insPt); +static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { + return (llvm::Twine(name) + + (ty.isIntOrIndex() ? 
llvm::Twine("_i_") : llvm::Twine("_f_")) + + llvm::Twine(ty.getIntOrFloatBitWidth())) + .str(); } -static void genObjectList(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands) { - auto addOperands = [&](Fortran::lower::SymbolRef sym) { - const mlir::Value variable = converter.getSymbolAddress(sym); - if (variable) { - operands.push_back(variable); - } else { - if (const auto *details = - sym->detailsIf()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } - } - }; - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - addOperands(*sym); +static std::string getReductionName( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { + std::string reductionName; + + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionName = "add_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionName = "multiply_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return "and_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return "eqv_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + return "or_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return "neqv_reduction"; + default: + reductionName = "other_reduction"; + break; } -} -static mlir::Value -getIfClauseOperand(Fortran::lower::AbstractConverter &converter, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClause::If *ifClause, - mlir::Location clauseLocation) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &expr = std::get(ifClause->v.t); - mlir::Value ifVal = fir::getBase( - 
converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); - return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), - ifVal); + return getReductionName(reductionName, ty); } -static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, - std::size_t loopVarTypeSize) { - // OpenMP runtime requires 32-bit or 64-bit loop variables. - loopVarTypeSize = loopVarTypeSize * 8; - if (loopVarTypeSize < 32) { - loopVarTypeSize = 32; - } else if (loopVarTypeSize > 64) { - loopVarTypeSize = 64; - mlir::emitWarning(converter.getCurrentLocation(), - "OpenMP loop iteration variable cannot have more than 64 " - "bits size and will be narrowed into 64 bits."); - } - assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && - "OpenMP loop iteration variable size must be transformed into 32-bit " - "or 64-bit"); - return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); +/// This function returns the identity value of the operator \p reductionOpName. +/// For example: +/// 0 + x = x, +/// 1 * x = x +static int getOperationIdentity(llvm::StringRef reductionOpName, + mlir::Location loc) { + if (reductionOpName.contains("add") || reductionOpName.contains("or") || + reductionOpName.contains("neqv")) + return 0; + if (reductionOpName.contains("multiply") || reductionOpName.contains("and") || + reductionOpName.contains("eqv")) + return 1; + TODO(loc, "Reduction of some intrinsic operators is not supported"); } -/// Create empty blocks for the current region. -/// These blocks replace blocks parented to an enclosing region. 
-void createEmptyRegionBlocks( - fir::FirOpBuilder &firOpBuilder, - std::list &evaluationList) { - auto *region = &firOpBuilder.getRegion(); - for (auto &eval : evaluationList) { - if (eval.block) { - if (eval.block->empty()) { - eval.block->erase(); - eval.block = firOpBuilder.createBlock(region); - } else { - [[maybe_unused]] auto &terminatorOp = eval.block->back(); - assert((mlir::isa(terminatorOp) || - mlir::isa(terminatorOp)) && - "expected terminator op"); - } - } - if (!eval.isDirective() && eval.hasNestedEvaluations()) - createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); - } -} - -void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder, - mlir::Operation *storeOp, mlir::Block &block) { - if (storeOp) - firOpBuilder.setInsertionPointAfter(storeOp); - else - firOpBuilder.setInsertionPointToStart(&block); -} - -/// Create the body (block) for an OpenMP Operation. -/// -/// \param [in] op - the operation the body belongs to. -/// \param [inout] converter - converter to use for the clauses. -/// \param [in] loc - location in source code. -/// \param [in] eval - current PFT node/evaluation. -/// \oaran [in] clauses - list of clauses to process. -/// \param [in] args - block arguments (induction variable[s]) for the -//// region. -/// \param [in] outerCombined - is this an outer operation - prevents -/// privatization. -template -static void -createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, - mlir::Location &loc, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList *clauses = nullptr, - const SmallVector &args = {}, - bool outerCombined = false, - DataSharingProcessor *dsp = nullptr) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - // If an argument for the region is provided then create the block with that - // argument. Also update the symbol's address with the mlir argument value. - // e.g. For loops the argument is the induction variable. 
And all further - // uses of the induction variable should use this mlir value. - mlir::Operation *storeOp = nullptr; - if (args.size()) { - std::size_t loopVarTypeSize = 0; - for (const Fortran::semantics::Symbol *arg : args) - loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - SmallVector tiv; - SmallVector locs; - for (int i = 0; i < (int)args.size(); i++) { - tiv.push_back(loopVarType); - locs.push_back(loc); +static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, + llvm::StringRef reductionOpName, + fir::FirOpBuilder &builder) { + assert((fir::isa_integer(type) || fir::isa_real(type) || + type.isa()) && + "only integer, logical and real types are currently supported"); + if (reductionOpName.contains("max")) { + if (auto ty = type.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); } - firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); - int argIndex = 0; - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. 
- for (const Fortran::semantics::Symbol *arg : args) { - mlir::Value val = - fir::getBase(op.getRegion().front().getArgument(argIndex)); - mlir::Value temp = firOpBuilder.createTemporary( - loc, loopVarType, - llvm::ArrayRef{ - Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); - storeOp = firOpBuilder.create(loc, val, temp); - converter.bindSymbol(*arg, temp); - argIndex++; + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, minInt); + } else if (reductionOpName.contains("min")) { + if (auto ty = type.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true)); } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, maxInt); + } else if (reductionOpName.contains("ior")) { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } else if (reductionOpName.contains("ieor")) { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } else if (reductionOpName.contains("iand")) { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, allOnInt); } else { - firOpBuilder.createBlock(&op.getRegion()); - } - // Set the insert for the terminator operation to go at the end of the - // block - this is either empty or the block with the stores above, - // the end of the block works for both. 
- mlir::Block &block = op.getRegion().back(); - firOpBuilder.setInsertionPointToEnd(&block); + if (type.isa()) + return builder.create( + loc, type, + builder.getFloatAttr( + type, (double)getOperationIdentity(reductionOpName, loc))); - // If it is an unstructured region and is not the outer region of a combined - // construct, create empty blocks for all evaluations. - if (eval.lowerAsUnstructured() && !outerCombined) - createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); + if (type.isa()) { + mlir::Value intConst = builder.create( + loc, builder.getI1Type(), + builder.getIntegerAttr(builder.getI1Type(), + getOperationIdentity(reductionOpName, loc))); + return builder.createConvert(loc, type, intConst); + } - // Insert the terminator. - if constexpr (std::is_same_v || - std::is_same_v) { - mlir::ValueRange results; - firOpBuilder.create(loc, results); - } else { - firOpBuilder.create(loc); + return builder.create( + loc, type, + builder.getIntegerAttr(type, + getOperationIdentity(reductionOpName, loc))); } - // Reset the insert point to before the terminator. - resetBeforeTerminator(firOpBuilder, storeOp, block); +} - // Handle privatization. Do not privatize if this is the outer operation. 
- if (clauses && !outerCombined) { - constexpr bool is_loop = std::is_same_v || - std::is_same_v; - if (!dsp) { - DataSharingProcessor proc(converter, *clauses, eval); - proc.processStep1(); - proc.processStep2(op, is_loop); - } else { - dsp->processStep2(op, is_loop); - } +template +static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2) { + assert(type.isIntOrIndexOrFloat() && + "only integer and float types are currently supported"); + if (type.isIntOrIndex()) + return builder.create(loc, op1, op2); + return builder.create(loc, op1, op2); +} - if (storeOp) - firOpBuilder.setInsertionPointAfter(storeOp); - } +static mlir::omp::ReductionDeclareOp +createMinimalReductionDecl(fir::FirOpBuilder &builder, + llvm::StringRef reductionOpName, mlir::Type type, + mlir::Location loc) { + mlir::ModuleOp module = builder.getModule(); + mlir::OpBuilder modBuilder(module.getBodyRegion()); - if constexpr (std::is_same_v) { - threadPrivatizeVars(converter, eval); - if (clauses) - genCopyinClause(converter, *clauses); - } + mlir::omp::ReductionDeclareOp decl = + modBuilder.create(loc, reductionOpName, + type); + builder.createBlock(&decl.getInitializerRegion(), + decl.getInitializerRegion().end(), {type}, {loc}); + builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); + mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder); + builder.create(loc, init); + + builder.createBlock(&decl.getReductionRegion(), + decl.getReductionRegion().end(), {type, type}, + {loc, loc}); + + return decl; } -static void createBodyOfTargetOp( - Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp, - const llvm::SmallVector &useDeviceTypes, - const llvm::SmallVector &useDeviceLocs, - const SmallVector &useDeviceSymbols, - const mlir::Location ¤tLocation) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = dataOp.getRegion(); 
+/// Creates an OpenMP reduction declaration and inserts it into the provided +/// symbol table. The declaration has a constant initializer with the neutral +/// value `initValue`, and the reduction combiner carried over from `reduce`. +/// TODO: Generalize this for non-integer types, add atomic region. +static mlir::omp::ReductionDeclareOp +createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + const Fortran::parser::ProcedureDesignator &procDesignator, + mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); - firOpBuilder.createBlock(®ion, {}, useDeviceTypes, useDeviceLocs); - firOpBuilder.create(currentLocation); - firOpBuilder.setInsertionPointToStart(®ion.front()); + auto decl = + module.lookupSymbol(reductionOpName); + if (decl) + return decl; - unsigned argIndex = 0; - for (auto *sym : useDeviceSymbols) { - const mlir::BlockArgument &arg = region.front().getArgument(argIndex); - mlir::Value val = fir::getBase(arg); - fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); - if (auto refType = val.getType().dyn_cast()) { - if (fir::isa_builtin_cptr_type(refType.getElementType())) { - converter.bindSymbol(*sym, val); - } else { - extVal.match( - [&](const fir::MutableBoxValue &mbv) { - converter.bindSymbol( - *sym, - fir::MutableBoxValue( - val, fir::factory::getNonDeferredLenParams(extVal), {})); - }, - [&](const auto &) { - TODO(converter.getCurrentLocation(), - "use_device clause operand unsupported type"); - }); - } + decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); + + mlir::Value reductionOp; + if (const auto *name{ + Fortran::parser::Unwrap(procDesignator)}) { + if (name->source == "max") { + reductionOp = 
+ getReductionOperation( + builder, type, loc, op1, op2); + } else if (name->source == "min") { + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + } else if (name->source == "ior") { + assert((type.isIntOrIndex()) && "only integer is expected"); + reductionOp = builder.create(loc, op1, op2); + } else if (name->source == "ieor") { + assert((type.isIntOrIndex()) && "only integer is expected"); + reductionOp = builder.create(loc, op1, op2); + } else if (name->source == "iand") { + assert((type.isIntOrIndex()) && "only integer is expected"); + reductionOp = builder.create(loc, op1, op2); } else { - TODO(converter.getCurrentLocation(), - "use_device clause operand unsupported type"); + TODO(loc, "Reduction of some intrinsic operators is not supported"); } - argIndex++; } + + builder.create(loc, reductionOp); + return decl; } -static void createTargetOp(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClauseList &opClauseList, - const llvm::omp::Directive &directive, - mlir::Location currentLocation, - Fortran::lower::pft::Evaluation *eval = nullptr) { - Fortran::lower::StatementContext stmtCtx; - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); +/// Creates an OpenMP reduction declaration and inserts it into the provided +/// symbol table. The declaration has a constant initializer with the neutral +/// value `initValue`, and the reduction combiner carried over from `reduce`. +/// TODO: Generalize this for non-integer types, add atomic region. 
+static mlir::omp::ReductionDeclareOp createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); - mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand; - mlir::UnitAttr nowaitAttr; - llvm::SmallVector mapOperands, devicePtrOperands, - deviceAddrOperands; - llvm::SmallVector mapTypes; - llvm::SmallVector useDeviceTypes; - llvm::SmallVector useDeviceLocs; - SmallVector useDeviceSymbols; - - /// Check for unsupported map operand types. - auto checkType = [](mlir::Location location, mlir::Type type) { - if (auto refType = type.dyn_cast()) - type = refType.getElementType(); - if (auto boxType = type.dyn_cast_or_null()) - if (!boxType.getElementType().isa()) - TODO(location, "OMPD_target_data MapOperand BoxType"); - }; + auto decl = + module.lookupSymbol(reductionOpName); + if (decl) + return decl; - auto addMapClause = [&](const auto &mapClause, mlir::Location &location) { - const auto &oMapType = - std::get>(mapClause->v.t); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; - // If the map type is specified, then process it else Tofrom is the default. 
- if (oMapType) { - const Fortran::parser::OmpMapType::Type &mapType = - std::get(oMapType->t); - switch (mapType) { - case Fortran::parser::OmpMapType::Type::To: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - break; - case Fortran::parser::OmpMapType::Type::From: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - break; - case Fortran::parser::OmpMapType::Type::Tofrom: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - break; - case Fortran::parser::OmpMapType::Type::Alloc: - case Fortran::parser::OmpMapType::Type::Release: - // alloc and release is the default map_type for the Target Data Ops, - // i.e. if no bits for map_type is supplied then alloc/release is - // implicitly assumed based on the target directive. Default value for - // Target Data and Enter Data is alloc and for Exit Data it is release. - break; - case Fortran::parser::OmpMapType::Type::Delete: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; - } + decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - if (std::get>( - oMapType->t)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; - } else { - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - } + mlir::Value reductionOp; + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case 
Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - // TODO: Add support MapTypeModifiers close, mapper, present, iterator + mlir::Value andiOp = builder.create(loc, op1I1, op2I1); - mlir::IntegerAttr mapTypeAttr = firOpBuilder.getIntegerAttr( - firOpBuilder.getI64Type(), - static_cast< - std::underlying_type_t>( - mapTypeBits)); + reductionOp = builder.createConvert(loc, type, andiOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - llvm::SmallVector mapOperand; - /// Check for unsupported map operand types. - for (const Fortran::parser::OmpObject &ompObject : - std::get(mapClause->v.t).v) { - if (Fortran::parser::Unwrap(ompObject) || - Fortran::parser::Unwrap( - ompObject)) - TODO(location, - "OMPD_target_data for Array Expressions or Structure Components"); - } - genObjectList(std::get(mapClause->v.t), - converter, mapOperand); + mlir::Value oriOp = builder.create(loc, op1I1, op2I1); - for (mlir::Value mapOp : mapOperand) { - checkType(mapOp.getLoc(), mapOp.getType()); - mapOperands.push_back(mapOp); - mapTypes.push_back(mapTypeAttr); - } - }; + reductionOp = builder.createConvert(loc, type, oriOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) { - genObjectList(useDeviceClause, converter, operands); - for (auto &operand : operands) { - checkType(operand.getLoc(), operand.getType()); - useDeviceTypes.push_back(operand.getType()); - 
useDeviceLocs.push_back(operand.getLoc()); - } - for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - useDeviceSymbols.push_back(sym); - } - }; + mlir::Value cmpiOp = builder.create( + loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - mlir::Location clauseLocation = converter.genLocation(clause.source); - if (const auto &ifClause = - std::get_if(&clause.u)) { - ifClauseOperand = - getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation); - } else if (const auto &deviceClause = - std::get_if(&clause.u)) { - if (auto deviceModifier = std::get< - std::optional>( - deviceClause->v.t)) { - if (deviceModifier == - Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) { - TODO(clauseLocation, "OMPD_target Device Modifier Ancestor"); - } - } - if (const auto *deviceExpr = Fortran::semantics::GetExpr( - std::get(deviceClause->v.t))) { - deviceOperand = - fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); - } - } else if (const auto &threadLmtClause = - std::get_if( - &clause.u)) { - threadLmtOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); - } else if (std::get_if(&clause.u)) { - nowaitAttr = firOpBuilder.getUnitAttr(); - } else if (const auto &devPtrClause = - std::get_if( - &clause.u)) { - addUseDeviceClause(devPtrClause->v, devicePtrOperands); - } else if (const auto &devAddrClause = - std::get_if( - &clause.u)) { - addUseDeviceClause(devAddrClause->v, deviceAddrOperands); - } else if (const auto &mapClause = - std::get_if(&clause.u)) { - addMapClause(mapClause, clauseLocation); - } else { - TODO(clauseLocation, "OMPD_target unhandled clause"); - } + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; } + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { + mlir::Value op1I1 = builder.createConvert(loc, 
builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - llvm::SmallVector mapTypesAttr(mapTypes.begin(), - mapTypes.end()); - mlir::ArrayAttr mapTypesArrayAttr = - ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); + mlir::Value cmpiOp = builder.create( + loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); - if (directive == llvm::omp::Directive::OMPD_target) { - auto targetOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, threadLmtOperand, - nowaitAttr, mapOperands, mapTypesArrayAttr); - createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList); - } else if (directive == llvm::omp::Directive::OMPD_target_data) { - auto dataOp = firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, - deviceAddrOperands, mapOperands, mapTypesArrayAttr); - createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs, - useDeviceSymbols, currentLocation); - } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) { - firOpBuilder.create(currentLocation, ifClauseOperand, - deviceOperand, nowaitAttr, - mapOperands, mapTypesArrayAttr); - } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) { - firOpBuilder.create(currentLocation, ifClauseOperand, - deviceOperand, nowaitAttr, mapOperands, - mapTypesArrayAttr); - } else { - TODO(currentLocation, "OMPD_target directive unknown"); + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; } + default: + TODO(loc, "Reduction of some intrinsic operators is not supported"); + } + + builder.create(loc, reductionOp); + return decl; } -static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPSimpleStandaloneConstruct - &simpleStandaloneConstruct) { - const auto &directive = - std::get( - simpleStandaloneConstruct.t); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const 
Fortran::parser::OmpClauseList &opClauseList = - std::get(simpleStandaloneConstruct.t); - mlir::Location currentLocation = converter.genLocation(directive.source); +static mlir::omp::ScheduleModifier +translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) { + switch (m.v) { + case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: + return mlir::omp::ScheduleModifier::monotonic; + case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: + return mlir::omp::ScheduleModifier::nonmonotonic; + case Fortran::parser::OmpScheduleModifierType::ModType::Simd: + return mlir::omp::ScheduleModifier::simd; + } + return mlir::omp::ScheduleModifier::none; +} - switch (directive.v) { - default: - break; - case llvm::omp::Directive::OMPD_barrier: - firOpBuilder.create(currentLocation); - break; - case llvm::omp::Directive::OMPD_taskwait: - firOpBuilder.create(currentLocation); - break; - case llvm::omp::Directive::OMPD_taskyield: - firOpBuilder.create(currentLocation); - break; - case llvm::omp::Directive::OMPD_target_data: - case llvm::omp::Directive::OMPD_target_enter_data: - case llvm::omp::Directive::OMPD_target_exit_data: - createTargetOp(converter, opClauseList, directive.v, currentLocation); - break; - case llvm::omp::Directive::OMPD_target_update: - TODO(currentLocation, "OMPD_target_update"); - case llvm::omp::Directive::OMPD_ordered: - TODO(currentLocation, "OMPD_ordered"); +static mlir::omp::ScheduleModifier +getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { + const auto &modifier = + std::get>(x.t); + // The input may have the modifier any order, so we look for one that isn't + // SIMD. If modifier is not set at all, fall down to the bottom and return + // "none". 
+ if (modifier) { + const auto &modType1 = + std::get(modifier->t); + if (modType1.v.v == + Fortran::parser::OmpScheduleModifierType::ModType::Simd) { + const auto &modType2 = std::get< + std::optional>( + modifier->t); + if (modType2 && + modType2->v.v != + Fortran::parser::OmpScheduleModifierType::ModType::Simd) + return translateScheduleModifier(modType2->v); + + return mlir::omp::ScheduleModifier::none; + } + + return translateScheduleModifier(modType1.v); } + return mlir::omp::ScheduleModifier::none; +} + +static mlir::omp::ScheduleModifier +getSimdModifier(const Fortran::parser::OmpScheduleClause &x) { + const auto &modifier = + std::get>(x.t); + // Either of the two possible modifiers in the input can be the SIMD modifier, + // so look in either one, and return simd if we find one. Not found = return + // "none". + if (modifier) { + const auto &modType1 = + std::get(modifier->t); + if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) + return mlir::omp::ScheduleModifier::simd; + + const auto &modType2 = std::get< + std::optional>( + modifier->t); + if (modType2 && modType2->v.v == + Fortran::parser::OmpScheduleModifierType::ModType::Simd) + return mlir::omp::ScheduleModifier::simd; + } + return mlir::omp::ScheduleModifier::none; } static void genAllocateClause(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpAllocateClause &ompAllocateClause, - SmallVector &allocatorOperands, - SmallVector &allocateOperands) { - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + llvm::SmallVectorImpl &allocatorOperands, + llvm::SmallVectorImpl &allocateOperands) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); Fortran::lower::StatementContext stmtCtx; mlir::Value allocatorOperand; @@ -1010,7 +969,7 @@ allocateModifier->u); if (allocateModifier && !onlyAllocator) { - 
TODO(converter.getCurrentLocation(), "OmpAllocateClause ALIGN modifier"); + TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); } // Check if allocate clause has allocator specified. If so, add it @@ -1022,592 +981,1293 @@ allocateModifier->u); allocatorOperand = fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); - } else { - allocatorOperand = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); - } - genObjectList(ompObjectList, converter, allocateOperands); -} - -static void -genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct - &simpleStandaloneConstruct) { - genOMP(converter, eval, simpleStandaloneConstruct); - }, - [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { - SmallVector operandRange; - if (const auto &ompObjectList = - std::get>( - flushConstruct.t)) - genObjectList(*ompObjectList, converter, operandRange); - const auto &memOrderClause = std::get>>( - flushConstruct.t); - if (memOrderClause.has_value() && memOrderClause->size() > 0) - TODO(converter.getCurrentLocation(), - "Handle OmpMemoryOrderClause"); - converter.getFirOpBuilder().create( - converter.getCurrentLocation(), operandRange); - }, - [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); - }, - [&](const Fortran::parser::OpenMPCancellationPointConstruct - &cancellationPointConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); - }, - }, - standaloneConstruct.u); + 
allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), + allocatorOperand); + } else { + allocatorOperand = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getI32Type(), 1); + allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), + allocatorOperand); + } + genObjectList(ompObjectList, converter, allocateOperands); } -static omp::ClauseProcBindKindAttr genProcBindKindAttr( +static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr( fir::FirOpBuilder &firOpBuilder, const Fortran::parser::OmpClause::ProcBind *procBindClause) { - omp::ClauseProcBindKind pbKind; + mlir::omp::ClauseProcBindKind procBindKind; switch (procBindClause->v.v) { case Fortran::parser::OmpProcBindClause::Type::Master: - pbKind = omp::ClauseProcBindKind::Master; + procBindKind = mlir::omp::ClauseProcBindKind::Master; break; case Fortran::parser::OmpProcBindClause::Type::Close: - pbKind = omp::ClauseProcBindKind::Close; + procBindKind = mlir::omp::ClauseProcBindKind::Close; break; case Fortran::parser::OmpProcBindClause::Type::Spread: - pbKind = omp::ClauseProcBindKind::Spread; + procBindKind = mlir::omp::ClauseProcBindKind::Spread; break; case Fortran::parser::OmpProcBindClause::Type::Primary: - pbKind = omp::ClauseProcBindKind::Primary; + procBindKind = mlir::omp::ClauseProcBindKind::Primary; break; } - return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); + return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), + procBindKind); } -static omp::ClauseTaskDependAttr +static mlir::omp::ClauseTaskDependAttr genDependKindAttr(fir::FirOpBuilder &firOpBuilder, const Fortran::parser::OmpClause::Depend *dependClause) { - omp::ClauseTaskDepend pbKind; + mlir::omp::ClauseTaskDepend pbKind; switch ( std::get( std::get(dependClause->v.u) .t) .v) { case Fortran::parser::OmpDependenceType::Type::In: - pbKind = omp::ClauseTaskDepend::taskdependin; + pbKind = mlir::omp::ClauseTaskDepend::taskdependin; break; case 
Fortran::parser::OmpDependenceType::Type::Out: - pbKind = omp::ClauseTaskDepend::taskdependout; + pbKind = mlir::omp::ClauseTaskDepend::taskdependout; break; case Fortran::parser::OmpDependenceType::Type::Inout: - pbKind = omp::ClauseTaskDepend::taskdependinout; + pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; break; default: llvm_unreachable("unknown parser task dependence type"); break; } - return omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), pbKind); + return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), + pbKind); +} + +static mlir::Value +getIfClauseOperand(Fortran::lower::AbstractConverter &converter, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClause::If *ifClause, + mlir::Location clauseLocation) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + auto &expr = std::get(ifClause->v.t); + mlir::Value ifVal = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), + ifVal); +} + +/// Creates a reduction declaration and associates it with an OpenMP block +/// directive. 
+static void +addReductionDecl(mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::omp::ReductionDeclareOp decl; + const auto &redOperator{ + std::get(reduction.t)}; + const auto &objectList{std::get(reduction.t)}; + if (const auto &redDefinedOp = + std::get_if(&redOperator.u)) { + const auto &intrinsicOp{ + std::get( + redDefinedOp->u)}; + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + break; + + default: + TODO(currentLocation, + "Reduction of some intrinsic operators is not supported"); + break; + } + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), + intrinsicOp, redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + intrinsicOp, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), 
decl.getSymName())); + } + } + } + } else if (const auto *reductionIntrinsic = + std::get_if( + &redOperator.u)) { + if (const auto *name{Fortran::parser::Unwrap( + reductionIntrinsic)}) { + if ((name->source != "max") && (name->source != "min") && + (name->source != "ior") && (name->source != "ieor") && + (name->source != "iand")) { + TODO(currentLocation, + "Reduction of intrinsic procedures is not supported"); + } + std::string intrinsicOp = name->ToString(); + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, getReductionName(intrinsicOp, redType), + *reductionIntrinsic, redType, currentLocation); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } + } +} + +static void +addUseDeviceClause(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpObjectList &useDeviceClause, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) { + genObjectList(useDeviceClause, converter, operands); + for (mlir::Value &operand : operands) { + checkMapType(operand.getLoc(), operand.getType()); + useDeviceTypes.push_back(operand.getType()); + useDeviceLocs.push_back(operand.getLoc()); + } + for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + useDeviceSymbols.push_back(sym); + } +} + +//===----------------------------------------------------------------------===// +// ClauseProcessor 
unique clauses +//===----------------------------------------------------------------------===// + +bool ClauseProcessor::processCollapse( + mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, + llvm::SmallVectorImpl &lowerBound, + llvm::SmallVectorImpl &upperBound, + llvm::SmallVectorImpl &step, + llvm::SmallVectorImpl &iv, + std::size_t &loopVarTypeSize) const { + bool found = false; + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + // Collect the loops to collapse. + Fortran::lower::pft::Evaluation *doConstructEval = + &eval.getFirstNestedEvaluation(); + if (doConstructEval->getIf() + ->IsDoConcurrent()) { + TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); + } + + std::int64_t collapseValue = 1l; + if (auto *collapseClause = findUniqueClause()) { + const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); + collapseValue = Fortran::evaluate::ToInt64(*expr).value(); + found = true; + } + + loopVarTypeSize = 0; + do { + Fortran::lower::pft::Evaluation *doLoop = + &doConstructEval->getFirstNestedEvaluation(); + auto *doStmt = doLoop->getIf(); + assert(doStmt && "Expected do loop to be in the nested evaluation"); + const auto &loopControl = + std::get>(doStmt->t); + const Fortran::parser::LoopControl::Bounds *bounds = + std::get_if(&loopControl->u); + assert(bounds && "Expected bounds for worksharing do loop"); + Fortran::lower::StatementContext stmtCtx; + lowerBound.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); + upperBound.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); + if (bounds->step) { + step.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); + } else { // If `step` is not present, assume it as `1`. 
+ step.push_back(firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIntegerType(32), 1)); + } + iv.push_back(bounds->name.thing.symbol); + loopVarTypeSize = std::max(loopVarTypeSize, + bounds->name.thing.symbol->GetUltimate().size()); + collapseValue--; + doConstructEval = + &*std::next(doConstructEval->getNestedEvaluations().begin()); + } while (collapseValue > 0); + + return found; +} + +bool ClauseProcessor::processDefault() const { + if (auto *defaultClause = findUniqueClause()) { + // Private, Firstprivate, Shared, None + switch (defaultClause->v.v) { + case Fortran::parser::OmpDefaultClause::Type::Shared: + case Fortran::parser::OmpDefaultClause::Type::None: + // Default clause with shared or none do not require any handling since + // Shared is the default behavior in the IR and None is only required + // for semantic checks. + break; + case Fortran::parser::OmpDefaultClause::Type::Private: + // TODO Support default(private) + break; + case Fortran::parser::OmpDefaultClause::Type::Firstprivate: + // TODO Support default(firstprivate) + break; + } + return true; + } + return false; +} + +bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + const Fortran::parser::CharBlock *source = nullptr; + if (auto *deviceClause = findUniqueClause(&source)) { + mlir::Location clauseLocation = converter.genLocation(*source); + if (auto deviceModifier = std::get< + std::optional>( + deviceClause->v.t)) { + if (deviceModifier == + Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) { + TODO(clauseLocation, "OMPD_target Device Modifier Ancestor"); + } + } + if (const auto *deviceExpr = Fortran::semantics::GetExpr( + std::get(deviceClause->v.t))) { + result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); + } + return true; + } + return false; +} + +bool ClauseProcessor::processDeviceType( + mlir::omp::DeclareTargetDeviceType &result) const { + if (auto *deviceTypeClause = 
findUniqueClause()) { + // Case: declare target ... device_type(any | host | nohost) + switch (deviceTypeClause->v.v) { + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + result = mlir::omp::DeclareTargetDeviceType::nohost; + break; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + result = mlir::omp::DeclareTargetDeviceType::host; + break; + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + result = mlir::omp::DeclareTargetDeviceType::any; + break; + } + return true; + } + return false; +} + +bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + const Fortran::parser::CharBlock *source = nullptr; + if (auto *finalClause = findUniqueClause(&source)) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location clauseLocation = converter.genLocation(*source); + + mlir::Value finalVal = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); + result = firOpBuilder.createConvert(clauseLocation, + firOpBuilder.getI1Type(), finalVal); + return true; + } + return false; +} + +bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { + if (auto *hintClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + const auto *expr = Fortran::semantics::GetExpr(hintClause->v); + int64_t hintValue = *Fortran::evaluate::ToInt64(*expr); + result = firOpBuilder.getI64IntegerAttr(hintValue); + return true; + } + return false; +} + +bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const { + return markClauseOccurrence(result); +} + +bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const { + return markClauseOccurrence(result); +} + +bool ClauseProcessor::processNumThreads( + Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + if (auto *numThreadsClause = findUniqueClause()) { + // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. 
+ result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { + if (auto *orderedClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + int64_t orderedClauseValue = 0l; + if (orderedClause->v.has_value()) { + const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); + orderedClauseValue = *Fortran::evaluate::ToInt64(*expr); + } + result = firOpBuilder.getI64IntegerAttr(orderedClauseValue); + return true; + } + return false; +} + +bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + if (auto *priorityClause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processProcBind( + mlir::omp::ClauseProcBindKindAttr &result) const { + if (auto *procBindClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + result = genProcBindKindAttr(firOpBuilder, procBindClause); + return true; + } + return false; +} + +bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const { + if (auto *safelenClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + const auto *expr = Fortran::semantics::GetExpr(safelenClause->v); + const std::optional safelenVal = + Fortran::evaluate::ToInt64(*expr); + result = firOpBuilder.getI64IntegerAttr(*safelenVal); + return true; + } + return false; +} + +bool ClauseProcessor::processSchedule( + mlir::omp::ClauseScheduleKindAttr &valAttr, + mlir::omp::ScheduleModifierAttr &modifierAttr, + mlir::UnitAttr &simdModifierAttr) const { + if (auto *scheduleClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + 
mlir::MLIRContext *context = firOpBuilder.getContext(); + const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v; + const auto &scheduleClauseKind = + std::get( + scheduleType.t); + + mlir::omp::ClauseScheduleKind scheduleKind; + switch (scheduleClauseKind) { + case Fortran::parser::OmpScheduleClause::ScheduleType::Static: + scheduleKind = mlir::omp::ClauseScheduleKind::Static; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: + scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: + scheduleKind = mlir::omp::ClauseScheduleKind::Guided; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: + scheduleKind = mlir::omp::ClauseScheduleKind::Auto; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: + scheduleKind = mlir::omp::ClauseScheduleKind::Runtime; + break; + } + + mlir::omp::ScheduleModifier scheduleModifier = + getScheduleModifier(scheduleClause->v); + + if (scheduleModifier != mlir::omp::ScheduleModifier::none) + modifierAttr = + mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier); + + if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none) + simdModifierAttr = firOpBuilder.getUnitAttr(); + + valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); + return true; + } + return false; +} + +bool ClauseProcessor::processScheduleChunk( + Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + if (auto *scheduleClause = findUniqueClause()) { + if (const auto &chunkExpr = + std::get>( + scheduleClause->v.t)) { + if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { + result = fir::getBase(converter.genExprValue(*expr, stmtCtx)); + } + } + return true; + } + return false; +} + +bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { + if (auto *simdlenClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = 
converter.getFirOpBuilder(); + const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v); + const std::optional simdlenVal = + Fortran::evaluate::ToInt64(*expr); + result = firOpBuilder.getI64IntegerAttr(*simdlenVal); + return true; + } + return false; +} + +bool ClauseProcessor::processThreadLimit( + Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + if (auto *threadLmtClause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { + return markClauseOccurrence(result); +} + +//===----------------------------------------------------------------------===// +// ClauseProcessor repeatable clauses +//===----------------------------------------------------------------------===// + +bool ClauseProcessor::processAllocate( + llvm::SmallVectorImpl &allocatorOperands, + llvm::SmallVectorImpl &allocateOperands) const { + return findRepeatableClause( + [&](const ClauseTy::Allocate *allocateClause, + const Fortran::parser::CharBlock &) { + genAllocateClause(converter, allocateClause->v, allocatorOperands, + allocateOperands); + }); +} + +bool ClauseProcessor::processCopyin() const { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); + + bool hasCopyin = findRepeatableClause( + [&](const ClauseTy::Copyin *copyinClause, + const Fortran::parser::CharBlock &) { + const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; + for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + if (sym->has()) + TODO(converter.getCurrentLocation(), + "common block in Copyin clause"); + if 
(Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate())) + TODO(converter.getCurrentLocation(), + "pointer or allocatable variables in Copyin clause"); + assert(sym->has() && + "No host-association found"); + converter.copyHostAssociateVar(*sym); + } + }); + + // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to + // the execution of the associated structured block. Emit implicit barrier to + // synchronize threads and avoid data races on propagation master's thread + // values of threadprivate variables to local instances of that variables of + // all other implicit threads. + if (hasCopyin) + firOpBuilder.create(converter.getCurrentLocation()); + firOpBuilder.restoreInsertionPoint(insPt); + return hasCopyin; } -/* When parallel is used in a combined construct, then use this function to - * create the parallel operation. It handles the parallel specific clauses - * and leaves the rest for handling at the inner operations. - * TODO: Refactor clause handling - */ -template -static void -createCombinedParallelOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Directive &directive) { +bool ClauseProcessor::processDepend( + llvm::SmallVectorImpl &dependTypeOperands, + llvm::SmallVectorImpl &dependOperands) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - Fortran::lower::StatementContext stmtCtx; - llvm::ArrayRef argTy; - mlir::Value ifClauseOperand, numThreadsClauseOperand; - SmallVector allocatorOperands, allocateOperands; - mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - const auto &opClauseList = - std::get(directive.t); - // TODO: Handle the following clauses - // 1. 
default - // Note: rest of the clauses are handled when the inner operation is created - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - mlir::Location clauseLocation = converter.genLocation(clause.source); - if (const auto &ifClause = - std::get_if(&clause.u)) { - ifClauseOperand = - getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation); - } else if (const auto &numThreadsClause = - std::get_if( - &clause.u)) { - numThreadsClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); - } else if (const auto &procBindClause = - std::get_if( - &clause.u)) { - procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); - } - } - // Create and insert the operation. - auto parallelOp = firOpBuilder.create( - currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), - /*reductions=*/nullptr, procBindKindAttr); - createBodyOfOp(parallelOp, converter, currentLocation, eval, - &opClauseList, /*iv=*/{}, - /*isCombined=*/true); + return findRepeatableClause( + [&](const ClauseTy::Depend *dependClause, + const Fortran::parser::CharBlock &) { + const std::list &depVal = + std::get>( + std::get( + dependClause->v.u) + .t); + mlir::omp::ClauseTaskDependAttr dependTypeOperand = + genDependKindAttr(firOpBuilder, dependClause); + dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), + dependTypeOperand); + for (const Fortran::parser::Designator &ompObject : depVal) { + Fortran::semantics::Symbol *sym = nullptr; + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::DataRef &designator) { + if (const Fortran::parser::Name *name = + std::get_if(&designator.u)) { + sym = name->symbol; + } else if (std::get_if>( + &designator.u)) { + TODO(converter.getCurrentLocation(), + "array sections not supported for task depend"); + } + }, + [&](const Fortran::parser::Substring &designator) { + 
TODO(converter.getCurrentLocation(), + "substring not supported for task depend"); + }}, + (ompObject).u); + const mlir::Value variable = converter.getSymbolAddress(*sym); + dependOperands.push_back(variable); + } + }); } -/// This function returns the identity value of the operator \p reductionOpName. -/// For example: -/// 0 + x = x, -/// 1 * x = x -static int getOperationIdentity(llvm::StringRef reductionOpName, - mlir::Location loc) { - if (reductionOpName.contains("add") || reductionOpName.contains("or") || - reductionOpName.contains("neqv")) - return 0; - if (reductionOpName.contains("multiply") || reductionOpName.contains("and") || - reductionOpName.contains("eqv")) - return 1; - TODO(loc, "Reduction of some intrinsic operators is not supported"); +bool ClauseProcessor::processIf(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + return findRepeatableClause( + [&](const ClauseTy::If *ifClause, + const Fortran::parser::CharBlock &source) { + mlir::Location clauseLocation = converter.genLocation(source); + // TODO Consider DirectiveNameModifier of the `ifClause` to only search + // for an applicable 'if' clause. + result = + getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation); + }); } -static Value getReductionInitValue(mlir::Location loc, mlir::Type type, - llvm::StringRef reductionOpName, - fir::FirOpBuilder &builder) { - assert((fir::isa_integer(type) || fir::isa_real(type) || - type.isa()) && - "only integer, logical and real types are currently supported"); - if (reductionOpName.contains("max")) { - if (auto ty = type.dyn_cast()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); +bool ClauseProcessor::processLink( + llvm::SmallVectorImpl &result) const { + return findRepeatableClause( + [&](const ClauseTy::Link *linkClause, + const Fortran::parser::CharBlock &) { + // Case: declare target link(var1, var2)... 
+ gatherFuncAndVarSyms( + linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); + }); +} + +bool ClauseProcessor::processMap( + llvm::SmallVectorImpl &mapOperands, + llvm::SmallVectorImpl &mapTypes) const { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + return findRepeatableClause< + ClauseTy::Map>([&](const ClauseTy::Map *mapClause, + const Fortran::parser::CharBlock &source) { + mlir::Location clauseLocation = converter.genLocation(source); + const auto &oMapType = + std::get>(mapClause->v.t); + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + // If the map type is specified, then process it else Tofrom is the default. + if (oMapType) { + const Fortran::parser::OmpMapType::Type &mapType = + std::get(oMapType->t); + switch (mapType) { + case Fortran::parser::OmpMapType::Type::To: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + break; + case Fortran::parser::OmpMapType::Type::From: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + break; + case Fortran::parser::OmpMapType::Type::Tofrom: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + break; + case Fortran::parser::OmpMapType::Type::Alloc: + case Fortran::parser::OmpMapType::Type::Release: + // alloc and release is the default map_type for the Target Data Ops, + // i.e. if no bits for map_type is supplied then alloc/release is + // implicitly assumed based on the target directive. Default value for + // Target Data and Enter Data is alloc and for Exit Data it is release. 
+ break; + case Fortran::parser::OmpMapType::Type::Delete: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + } + + if (std::get>( + oMapType->t)) + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + } else { + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, minInt); - } else if (reductionOpName.contains("min")) { - if (auto ty = type.dyn_cast()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true)); + + // TODO: Add support MapTypeModifiers close, mapper, present, iterator + + mlir::IntegerAttr mapTypeAttr = firOpBuilder.getIntegerAttr( + firOpBuilder.getI64Type(), + static_cast< + std::underlying_type_t>( + mapTypeBits)); + + llvm::SmallVector mapOperand; + // Check for unsupported map operand types. 
+ for (const Fortran::parser::OmpObject &ompObject : + std::get(mapClause->v.t).v) { + if (Fortran::parser::Unwrap(ompObject) || + Fortran::parser::Unwrap( + ompObject)) + TODO(clauseLocation, + "OMPD_target_data for Array Expressions or Structure Components"); } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, maxInt); - } else if (reductionOpName.contains("ior")) { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } else if (reductionOpName.contains("ieor")) { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } else if (reductionOpName.contains("iand")) { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, allOnInt); - } else { - if (type.isa()) - return builder.create( - loc, type, - builder.getFloatAttr( - type, (double)getOperationIdentity(reductionOpName, loc))); + genObjectList(std::get(mapClause->v.t), + converter, mapOperand); - if (type.isa()) { - Value intConst = builder.create( - loc, builder.getI1Type(), - builder.getIntegerAttr(builder.getI1Type(), - getOperationIdentity(reductionOpName, loc))); - return builder.createConvert(loc, type, intConst); + for (mlir::Value mapOp : mapOperand) { + checkMapType(mapOp.getLoc(), mapOp.getType()); + mapOperands.push_back(mapOp); + mapTypes.push_back(mapTypeAttr); } + }); +} + +bool ClauseProcessor::processReduction( + mlir::Location currentLocation, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols) const { + return findRepeatableClause( + [&](const ClauseTy::Reduction *reductionClause, + const 
Fortran::parser::CharBlock &) { + addReductionDecl(currentLocation, converter, reductionClause->v, + reductionVars, reductionDeclSymbols); + }); +} - return builder.create( - loc, type, - builder.getIntegerAttr(type, - getOperationIdentity(reductionOpName, loc))); - } +bool ClauseProcessor::processSectionsReduction( + mlir::Location currentLocation) const { + return findRepeatableClause( + [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { + TODO(currentLocation, "OMPC_Reduction"); + }); } -template -static Value getReductionOperation(fir::FirOpBuilder &builder, mlir::Type type, - mlir::Location loc, mlir::Value op1, - mlir::Value op2) { - assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); - if (type.isIntOrIndex()) - return builder.create(loc, op1, op2); - return builder.create(loc, op1, op2); +bool ClauseProcessor::processTo( + llvm::SmallVectorImpl &result) const { + return findRepeatableClause( + [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { + // Case: declare target to(func, var1, var2)... 
+ gatherFuncAndVarSyms(toClause->v, + mlir::omp::DeclareTargetCaptureClause::to, result); + }); } -static omp::ReductionDeclareOp -createMinimalReductionDecl(fir::FirOpBuilder &builder, - llvm::StringRef reductionOpName, mlir::Type type, - mlir::Location loc) { - mlir::ModuleOp module = builder.getModule(); - mlir::OpBuilder modBuilder(module.getBodyRegion()); +bool ClauseProcessor::processUseDeviceAddr( + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSymbols) + const { + return findRepeatableClause( + [&](const ClauseTy::UseDeviceAddr *devAddrClause, + const Fortran::parser::CharBlock &) { + addUseDeviceClause(converter, devAddrClause->v, operands, + useDeviceTypes, useDeviceLocs, useDeviceSymbols); + }); +} - mlir::omp::ReductionDeclareOp decl = - modBuilder.create(loc, reductionOpName, type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - Value init = getReductionInitValue(loc, type, reductionOpName, builder); - builder.create(loc, init); +bool ClauseProcessor::processUseDevicePtr( + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSymbols) + const { + return findRepeatableClause( + [&](const ClauseTy::UseDevicePtr *devPtrClause, + const Fortran::parser::CharBlock &) { + addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, + useDeviceLocs, useDeviceSymbols); + }); +} - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); +//===----------------------------------------------------------------------===// +// Code generation helper functions +//===----------------------------------------------------------------------===// + +static fir::GlobalOp 
globalInitialization( + Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym, + const Fortran::lower::pft::Variable &var, mlir::Location currentLocation) { + mlir::Type ty = converter.genType(sym); + std::string globalName = converter.mangleName(sym); + mlir::StringAttr linkage = firOpBuilder.createInternalLinkage(); + fir::GlobalOp global = + firOpBuilder.createGlobal(currentLocation, ty, globalName, linkage); + + // Create default initialization for non-character scalar. + if (Fortran::semantics::IsAllocatableOrPointer(sym)) { + mlir::Type baseAddrType = ty.dyn_cast().getEleTy(); + Fortran::lower::createGlobalInitialization( + firOpBuilder, global, [&](fir::FirOpBuilder &b) { + mlir::Value nullAddr = + b.createNullConstant(currentLocation, baseAddrType); + mlir::Value box = + b.create(currentLocation, ty, nullAddr); + b.create(currentLocation, box); + }); + } else { + Fortran::lower::createGlobalInitialization( + firOpBuilder, global, [&](fir::FirOpBuilder &b) { + mlir::Value undef = b.create(currentLocation, ty); + b.create(currentLocation, undef); + }); + } + + return global; +} + +static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp, + mlir::Value loadVal) { + for (mlir::Value reductionOperand : reductionOp->getOperands()) { + if (mlir::Operation *compareOp = reductionOperand.getDefiningOp()) { + if (compareOp->getOperand(0) == loadVal || + compareOp->getOperand(1) == loadVal) + assert((mlir::isa(compareOp) || + mlir::isa(compareOp)) && + "Expected comparison not found in reduction intrinsic"); + return compareOp; // NOTE(review): the brace-less `if` above guards only the assert, so this returns the first operand that has a defining op even when neither of its operands is loadVal; confirm this is intended. + } + } + return nullptr; +} + +/// The COMMON block is a global structure. \p commonValue is the base address +/// of the COMMON block. Using the offset of the symbol \p sym, generate the +/// COMMON block member value (commonValue + offset) for the symbol. +/// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp.
+static mlir::Value +genCommonBlockMember(Fortran::lower::AbstractConverter &converter, + const Fortran::semantics::Symbol &sym, + mlir::Value commonValue) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8); + mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty); + mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty)); + mlir::Value base = + firOpBuilder.createConvert(currentLocation, seqTy, commonValue); + std::size_t byteOffset = sym.GetUltimate().offset(); + mlir::Value offs = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIndexType(), byteOffset); + mlir::Value varAddr = firOpBuilder.create( + currentLocation, i8Ptr, base, mlir::ValueRange{offs}); + mlir::Type symType = converter.genType(sym); + return firOpBuilder.createConvert(currentLocation, + firOpBuilder.getRefType(symType), varAddr); +} - return decl; +// Get the extended value for \p val by extracting additional variable +// information from \p base. +static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, + mlir::Value val) { + return base.match( + [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue { + return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {}); + }, + [&](const auto &) -> fir::ExtendedValue { + return fir::substBase(base, val); + }); } -/// Creates an OpenMP reduction declaration and inserts it into the provided -/// symbol table. The declaration has a constant initializer with the neutral -/// value `initValue`, and the reduction combiner carried over from `reduce`. -/// TODO: Generalize this for non-integer types, add atomic region. 
-static omp::ReductionDeclareOp -createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - const Fortran::parser::ProcedureDesignator &procDesignator, - mlir::Type type, mlir::Location loc) { - OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); +static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); - auto decl = - module.lookupSymbol(reductionOpName); - if (decl) - return decl; + // Get the original ThreadprivateOp corresponding to the symbol and use the + // symbol value from that operation to create one ThreadprivateOp copy + // operation inside the parallel region. + auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value { + mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym); + mlir::Operation *op = symOriThreadprivateValue.getDefiningOp(); + assert(mlir::isa(op) && + "The threadprivate operation not created"); + mlir::Value symValue = + mlir::dyn_cast(op).getSymAddr(); + return firOpBuilder.create( + currentLocation, symValue.getType(), symValue); + }; - decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); + llvm::SetVector threadprivateSyms; + converter.collectSymbolSet( + eval, threadprivateSyms, + Fortran::semantics::Symbol::Flag::OmpThreadprivate); + std::set threadprivateSymNames; - Value reductionOp; - if (const auto *name{ - Fortran::parser::Unwrap(procDesignator)}) { - if (name->source ==
"max") { - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - } else if (name->source == "min") { - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - } else if (name->source == "ior") { - assert((type.isIntOrIndex()) && "only integer is expected"); - reductionOp = builder.create(loc, op1, op2); - } else if (name->source == "ieor") { - assert((type.isIntOrIndex()) && "only integer is expected"); - reductionOp = builder.create(loc, op1, op2); - } else if (name->source == "iand") { - assert((type.isIntOrIndex()) && "only integer is expected"); - reductionOp = builder.create(loc, op1, op2); + // For a COMMON block, the ThreadprivateOp is generated for itself instead of + // its members, so only bind the value of the new copied ThreadprivateOp + // inside the parallel region to the common block symbol only once for + // multiple members in one COMMON block. + llvm::SetVector commonSyms; + for (std::size_t i = 0; i < threadprivateSyms.size(); i++) { + const Fortran::semantics::Symbol *sym = threadprivateSyms[i]; + mlir::Value symThreadprivateValue; + // The variable may be used more than once, and each reference has one + // symbol with the same name. Only do once for references of one variable. 
+ if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end()) + continue; + threadprivateSymNames.insert(sym->name()); + if (const Fortran::semantics::Symbol *common = + Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) { + mlir::Value commonThreadprivateValue; + if (commonSyms.contains(common)) { + commonThreadprivateValue = converter.getSymbolAddress(*common); + } else { + commonThreadprivateValue = genThreadprivateOp(*common); + converter.bindSymbol(*common, commonThreadprivateValue); + commonSyms.insert(common); + } + symThreadprivateValue = + genCommonBlockMember(converter, *sym, commonThreadprivateValue); } else { - TODO(loc, "Reduction of some intrinsic operators is not supported"); + symThreadprivateValue = genThreadprivateOp(*sym); } + + fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym); + fir::ExtendedValue symThreadprivateExv = + getExtendedValue(sexv, symThreadprivateValue); + converter.bindSymbol(*sym, symThreadprivateExv); } - builder.create(loc, reductionOp); - return decl; + firOpBuilder.restoreInsertionPoint(insPt); } -/// Creates an OpenMP reduction declaration and inserts it into the provided -/// symbol table. The declaration has a constant initializer with the neutral -/// value `initValue`, and the reduction combiner carried over from `reduce`. -/// TODO: Generalize this for non-integer types, add atomic region. 
-static omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type type, mlir::Location loc) { - OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol(reductionOpName); - if (decl) - return decl; - - decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - - Value reductionOp; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { - Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - Value andiOp = builder.create(loc, op1I1, op2I1); +static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, + std::size_t loopVarTypeSize) { + // OpenMP runtime requires 32-bit or 64-bit loop variables. 
+ loopVarTypeSize = loopVarTypeSize * 8; + if (loopVarTypeSize < 32) { + loopVarTypeSize = 32; + } else if (loopVarTypeSize > 64) { + loopVarTypeSize = 64; + mlir::emitWarning(converter.getCurrentLocation(), + "OpenMP loop iteration variable cannot have more than 64 " + "bits size and will be narrowed into 64 bits."); + } + assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && + "OpenMP loop iteration variable size must be transformed into 32-bit " + "or 64-bit"); + return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); +} - reductionOp = builder.createConvert(loc, type, andiOp); - break; +/// Create empty blocks for the current region. +/// These blocks replace blocks parented to an enclosing region. +static void createEmptyRegionBlocks( + fir::FirOpBuilder &firOpBuilder, + std::list &evaluationList) { + mlir::Region *region = &firOpBuilder.getRegion(); + for (Fortran::lower::pft::Evaluation &eval : evaluationList) { + if (eval.block) { + if (eval.block->empty()) { + eval.block->erase(); + eval.block = firOpBuilder.createBlock(region); + } else { + [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back(); + assert((mlir::isa(terminatorOp) || + mlir::isa(terminatorOp)) && + "expected terminator op"); + } + } + if (!eval.isDirective() && eval.hasNestedEvaluations()) + createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); } - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { - Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); +} - Value oriOp = builder.create(loc, op1I1, op2I1); +static void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder, + mlir::Operation *storeOp, + mlir::Block &block) { + if (storeOp) + firOpBuilder.setInsertionPointAfter(storeOp); + else + firOpBuilder.setInsertionPointToStart(&block); +} - reductionOp = builder.createConvert(loc, type, oriOp); - break; +/// Create the body (block) for an OpenMP 
Operation. +/// +/// \param [in] op - the operation the body belongs to. +/// \param [inout] converter - converter to use for the clauses. +/// \param [in] loc - location in source code. +/// \param [in] eval - current PFT node/evaluation. +/// \param [in] clauses - list of clauses to process. +/// \param [in] args - block arguments (induction variable[s]) for the +/// region. +/// \param [in] outerCombined - is this an outer operation - prevents +/// privatization. +template +static void createBodyOfOp( + Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList *clauses = nullptr, + const llvm::SmallVector &args = {}, + bool outerCombined = false, DataSharingProcessor *dsp = nullptr) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + // If an argument for the region is provided then create the block with that + // argument. Also update the symbol's address with the mlir argument value. + // e.g. For loops the argument is the induction variable. And all further + // uses of the induction variable should use this mlir value. + mlir::Operation *storeOp = nullptr; + if (args.size()) { + std::size_t loopVarTypeSize = 0; + for (const Fortran::semantics::Symbol *arg : args) + loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + llvm::SmallVector tiv; + llvm::SmallVector locs; + for (int i = 0; i < (int)args.size(); i++) { + tiv.push_back(loopVarType); + locs.push_back(loc); + } + firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); + int argIndex = 0; + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument.
+ for (const Fortran::semantics::Symbol *arg : args) { + mlir::Value val = + fir::getBase(op.getRegion().front().getArgument(argIndex)); + mlir::Value temp = firOpBuilder.createTemporary( + loc, loopVarType, + llvm::ArrayRef{ + Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); + storeOp = firOpBuilder.create(loc, val, temp); + converter.bindSymbol(*arg, temp); + argIndex++; + } + } else { + firOpBuilder.createBlock(&op.getRegion()); } - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { - Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + // Set the insert for the terminator operation to go at the end of the + // block - this is either empty or the block with the stores above, + // the end of the block works for both. + mlir::Block &block = op.getRegion().back(); + firOpBuilder.setInsertionPointToEnd(&block); - Value cmpiOp = builder.create( - loc, arith::CmpIPredicate::eq, op1I1, op2I1); + // If it is an unstructured region and is not the outer region of a combined + // construct, create empty blocks for all evaluations. + if (eval.lowerAsUnstructured() && !outerCombined) + createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; + // Insert the terminator. + if constexpr (std::is_same_v || + std::is_same_v) { + mlir::ValueRange results; + firOpBuilder.create(loc, results); + } else { + firOpBuilder.create(loc); } - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { - Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + // Reset the insert point to before the terminator. + resetBeforeTerminator(firOpBuilder, storeOp, block); - Value cmpiOp = builder.create( - loc, arith::CmpIPredicate::ne, op1I1, op2I1); + // Handle privatization. Do not privatize if this is the outer operation. 
+ if (clauses && !outerCombined) { + constexpr bool is_loop = std::is_same_v || + std::is_same_v; + if (!dsp) { + DataSharingProcessor proc(converter, *clauses, eval); + proc.processStep1(); + proc.processStep2(op, is_loop); + } else { + dsp->processStep2(op, is_loop); + } - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - default: - TODO(loc, "Reduction of some intrinsic operators is not supported"); + if (storeOp) + firOpBuilder.setInsertionPointAfter(storeOp); } - builder.create(loc, reductionOp); - return decl; -} - -static mlir::omp::ScheduleModifier -translateModifier(const Fortran::parser::OmpScheduleModifierType &m) { - switch (m.v) { - case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: - return mlir::omp::ScheduleModifier::monotonic; - case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: - return mlir::omp::ScheduleModifier::nonmonotonic; - case Fortran::parser::OmpScheduleModifierType::ModType::Simd: - return mlir::omp::ScheduleModifier::simd; + if constexpr (std::is_same_v) { + threadPrivatizeVars(converter, eval); + if (clauses) + ClauseProcessor(converter, *clauses).processCopyin(); } - return mlir::omp::ScheduleModifier::none; } -static mlir::omp::ScheduleModifier -getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { - const auto &modifier = - std::get>(x.t); - // The input may have the modifier any order, so we look for one that isn't - // SIMD. If modifier is not set at all, fall down to the bottom and return - // "none". 
- if (modifier) { - const auto &modType1 = - std::get(modifier->t); - if (modType1.v.v == - Fortran::parser::OmpScheduleModifierType::ModType::Simd) { - const auto &modType2 = std::get< - std::optional>( - modifier->t); - if (modType2 && - modType2->v.v != - Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return translateModifier(modType2->v); +static void +createBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, + mlir::omp::DataOp &dataOp, + const llvm::SmallVector &useDeviceTypes, + const llvm::SmallVector &useDeviceLocs, + const llvm::SmallVector + &useDeviceSymbols, + const mlir::Location ¤tLocation) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Region ®ion = dataOp.getRegion(); - return mlir::omp::ScheduleModifier::none; - } + firOpBuilder.createBlock(®ion, {}, useDeviceTypes, useDeviceLocs); + firOpBuilder.create(currentLocation); + firOpBuilder.setInsertionPointToStart(®ion.front()); - return translateModifier(modType1.v); + unsigned argIndex = 0; + for (const Fortran::semantics::Symbol *sym : useDeviceSymbols) { + const mlir::BlockArgument &arg = region.front().getArgument(argIndex); + mlir::Value val = fir::getBase(arg); + fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); + if (auto refType = val.getType().dyn_cast()) { + if (fir::isa_builtin_cptr_type(refType.getElementType())) { + converter.bindSymbol(*sym, val); + } else { + extVal.match( + [&](const fir::MutableBoxValue &mbv) { + converter.bindSymbol( + *sym, + fir::MutableBoxValue( + val, fir::factory::getNonDeferredLenParams(extVal), {})); + }, + [&](const auto &) { + TODO(converter.getCurrentLocation(), + "use_device clause operand unsupported type"); + }); + } + } else { + TODO(converter.getCurrentLocation(), + "use_device clause operand unsupported type"); + } + argIndex++; } - return mlir::omp::ScheduleModifier::none; } -static mlir::omp::ScheduleModifier -getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) { - const 
auto &modifier = - std::get>(x.t); - // Either of the two possible modifiers in the input can be the SIMD modifier, - // so look in either one, and return simd if we find one. Not found = return - // "none". - if (modifier) { - const auto &modType1 = - std::get(modifier->t); - if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return mlir::omp::ScheduleModifier::simd; +static void createTargetOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &opClauseList, + const llvm::omp::Directive &directive, + mlir::Location currentLocation, + Fortran::lower::pft::Evaluation *eval = nullptr) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + Fortran::lower::StatementContext stmtCtx; + mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand; + mlir::UnitAttr nowaitAttr; + llvm::SmallVector mapOperands, devicePtrOperands, + deviceAddrOperands; + llvm::SmallVector mapTypes; + llvm::SmallVector useDeviceTypes; + llvm::SmallVector useDeviceLocs; + llvm::SmallVector useDeviceSymbols; + + ClauseProcessor cp(converter, opClauseList); + cp.processIf(stmtCtx, ifClauseOperand); + cp.processDevice(stmtCtx, deviceOperand); + cp.processThreadLimit(stmtCtx, threadLmtOperand); + cp.processNowait(nowaitAttr); + cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, + useDeviceSymbols); + cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs, + useDeviceSymbols); + cp.processMap(mapOperands, mapTypes); - const auto &modType2 = std::get< - std::optional>( - modifier->t); - if (modType2 && modType2->v.v == - Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return mlir::omp::ScheduleModifier::simd; + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (!std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + 
!std::get_if(&clause.u)) { + mlir::Location clauseLocation = converter.genLocation(clause.source); + TODO(clauseLocation, "OMPD_target unhandled clause"); + } } - return mlir::omp::ScheduleModifier::none; -} - -static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { - return (llvm::Twine(name) + - (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + - llvm::Twine(ty.getIntOrFloatBitWidth())) - .str(); -} -static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { - std::string reductionName; + llvm::SmallVector mapTypesAttr(mapTypes.begin(), + mapTypes.end()); + mlir::ArrayAttr mapTypesArrayAttr = + mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionName = "add_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionName = "multiply_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - return "neqv_reduction"; - default: - reductionName = "other_reduction"; - break; + if (directive == llvm::omp::Directive::OMPD_target) { + auto targetOp = firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, threadLmtOperand, + nowaitAttr, mapOperands, mapTypesArrayAttr); + createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList); + } else if (directive == llvm::omp::Directive::OMPD_target_data) { + auto dataOp = firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, + deviceAddrOperands, mapOperands, mapTypesArrayAttr); + createBodyOfTargetOp(converter, dataOp, 
useDeviceTypes, useDeviceLocs, + useDeviceSymbols, currentLocation); + } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) { + firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, + mapOperands, mapTypesArrayAttr); + } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) { + firOpBuilder.create(currentLocation, ifClauseOperand, + deviceOperand, nowaitAttr, + mapOperands, mapTypesArrayAttr); + } else { + TODO(currentLocation, "OMPD_target directive unknown"); } - - return getReductionName(reductionName, ty); } -/// Creates a reduction declaration and associates it with an -/// OpenMP block directive -static void -addReductionDecl(mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - SmallVector &reductionVars, - SmallVector &reductionDeclSymbols) { +static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPSimpleStandaloneConstruct + &simpleStandaloneConstruct) { + const auto &directive = + std::get( + simpleStandaloneConstruct.t); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - omp::ReductionDeclareOp decl; - const auto &redOperator{ - std::get(reduction.t)}; - const auto &objectList{std::get(reduction.t)}; - if (const auto &redDefinedOp = - std::get_if(&redOperator.u)) { - const auto &intrinsicOp{ - std::get( - redDefinedOp->u)}; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - break; + const auto &opClauseList = + std::get(simpleStandaloneConstruct.t); + 
mlir::Location currentLocation = converter.genLocation(directive.source); - default: - TODO(currentLocation, - "Reduction of some intrinsic operators is not supported"); - break; - } - for (const auto &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const auto *symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - if (redType.isa()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - intrinsicOp, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - intrinsicOp, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } - reductionDeclSymbols.push_back( - SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } else if (auto reductionIntrinsic = - std::get_if( - &redOperator.u)) { - if (const auto *name{Fortran::parser::Unwrap( - reductionIntrinsic)}) { - if ((name->source != "max") && (name->source != "min") && - (name->source != "ior") && (name->source != "ieor") && - (name->source != "iand")) { - TODO(currentLocation, - "Reduction of intrinsic procedures is not supported"); - } - std::string intrinsicOp = name->ToString(); - for (const auto &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const auto *symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, getReductionName(intrinsicOp, redType), - *reductionIntrinsic, redType, currentLocation); - 
reductionDeclSymbols.push_back(SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } + switch (directive.v) { + default: + break; + case llvm::omp::Directive::OMPD_barrier: + firOpBuilder.create(currentLocation); + break; + case llvm::omp::Directive::OMPD_taskwait: + firOpBuilder.create(currentLocation); + break; + case llvm::omp::Directive::OMPD_taskyield: + firOpBuilder.create(currentLocation); + break; + case llvm::omp::Directive::OMPD_target_data: + case llvm::omp::Directive::OMPD_target_enter_data: + case llvm::omp::Directive::OMPD_target_exit_data: + createTargetOp(converter, opClauseList, directive.v, currentLocation); + break; + case llvm::omp::Directive::OMPD_target_update: + TODO(currentLocation, "OMPD_target_update"); + case llvm::omp::Directive::OMPD_ordered: + TODO(currentLocation, "OMPD_ordered"); } } +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct + &simpleStandaloneConstruct) { + genOMP(converter, eval, simpleStandaloneConstruct); + }, + [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { + llvm::SmallVector operandRange; + if (const auto &ompObjectList = + std::get>( + flushConstruct.t)) + genObjectList(*ompObjectList, converter, operandRange); + const auto &memOrderClause = std::get>>( + flushConstruct.t); + if (memOrderClause && memOrderClause->size() > 0) + TODO(converter.getCurrentLocation(), + "Handle OmpMemoryOrderClause"); + converter.getFirOpBuilder().create( + converter.getCurrentLocation(), operandRange); + }, + [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); + }, + [&](const Fortran::parser::OpenMPCancellationPointConstruct + &cancellationPointConstruct) { + 
TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); + }, + }, + standaloneConstruct.u); +} + +/* When parallel is used in a combined construct, then use this function to + * create the parallel operation. It handles the parallel specific clauses + * and leaves the rest for handling at the inner operations. + */ +template +static void +createCombinedParallelOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Directive &directive) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; + llvm::ArrayRef argTy; + mlir::Value ifClauseOperand, numThreadsClauseOperand; + llvm::SmallVector allocatorOperands, allocateOperands; + mlir::omp::ClauseProcBindKindAttr procBindKindAttr; + const auto &opClauseList = + std::get(directive.t); + // TODO: Handle the following clauses + // 1. default + // Note: rest of the clauses are handled when the inner operation is created + ClauseProcessor cp(converter, opClauseList); + cp.processIf(stmtCtx, ifClauseOperand); + cp.processNumThreads(stmtCtx, numThreadsClauseOperand); + cp.processProcBind(procBindKindAttr); + + // Create and insert the operation. 
+ auto parallelOp = firOpBuilder.create( + currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, + allocateOperands, allocatorOperands, + /*reduction_vars=*/mlir::ValueRange(), + /*reductions=*/nullptr, procBindKindAttr); + + createBodyOfOp(parallelOp, converter, currentLocation, + eval, &opClauseList, /*iv=*/{}, + /*isCombined=*/true); +} + static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); llvm::SmallVector lowerBound, upperBound, step, linearVars, linearStepVars, reductionVars, alignedVars, nontemporalVars; mlir::Value scheduleChunkClauseOperand, ifClauseOperand; - mlir::Attribute scheduleClauseOperand, noWaitClauseOperand, - orderedClauseOperand, orderClauseOperand; + mlir::IntegerAttr orderedClauseOperand; + mlir::omp::ClauseOrderKindAttr orderClauseOperand; + mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand; + mlir::omp::ScheduleModifierAttr scheduleModClauseOperand; + mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand; mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand; - SmallVector reductionDeclSymbols; + llvm::SmallVector reductionDeclSymbols; Fortran::lower::StatementContext stmtCtx; - const auto &loopOpClauseList = std::get( - std::get(loopConstruct.t).t); + std::size_t loopVarTypeSize; + llvm::SmallVector iv; const auto &beginLoopDirective = std::get(loopConstruct.t); + const auto &loopOpClauseList = + std::get(beginLoopDirective.t); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); const auto ompDirective = @@ -1625,84 +2285,14 @@ DataSharingProcessor dsp(converter, loopOpClauseList, eval); dsp.processStep1(); - // Collect the loops to collapse. 
- auto *doConstructEval = &eval.getFirstNestedEvaluation(); - if (doConstructEval->getIf() - ->IsDoConcurrent()) { - TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); - } - - std::int64_t collapseValue = - Fortran::lower::getCollapseValue(loopOpClauseList); - std::size_t loopVarTypeSize = 0; - SmallVector iv; - do { - auto *doLoop = &doConstructEval->getFirstNestedEvaluation(); - auto *doStmt = doLoop->getIf(); - assert(doStmt && "Expected do loop to be in the nested evaluation"); - const auto &loopControl = - std::get>(doStmt->t); - const Fortran::parser::LoopControl::Bounds *bounds = - std::get_if(&loopControl->u); - assert(bounds && "Expected bounds for worksharing do loop"); - Fortran::lower::StatementContext stmtCtx; - lowerBound.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); - upperBound.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); - if (bounds->step) { - step.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); - } else { // If `step` is not present, assume it as `1`. 
- step.push_back(firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), 1)); - } - iv.push_back(bounds->name.thing.symbol); - loopVarTypeSize = std::max(loopVarTypeSize, - bounds->name.thing.symbol->GetUltimate().size()); - - collapseValue--; - doConstructEval = - &*std::next(doConstructEval->getNestedEvaluations().begin()); - } while (collapseValue > 0); - - for (const auto &clause : loopOpClauseList.v) { - mlir::Location clauseLocation = converter.genLocation(clause.source); - if (const auto &scheduleClause = - std::get_if(&clause.u)) { - if (const auto &chunkExpr = - std::get>( - scheduleClause->v.t)) { - if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { - scheduleChunkClauseOperand = - fir::getBase(converter.genExprValue(*expr, stmtCtx)); - } - } - } else if (const auto &ifClause = - std::get_if(&clause.u)) { - ifClauseOperand = - getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation); - } else if (const auto &reductionClause = - std::get_if( - &clause.u)) { - addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); - } else if (const auto &simdlenClause = - std::get_if( - &clause.u)) { - const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v); - const std::optional simdlenVal = - Fortran::evaluate::ToInt64(*expr); - simdlenClauseOperand = firOpBuilder.getI64IntegerAttr(*simdlenVal); - } else if (const auto &safelenClause = - std::get_if( - &clause.u)) { - const auto *expr = Fortran::semantics::GetExpr(safelenClause->v); - const std::optional safelenVal = - Fortran::evaluate::ToInt64(*expr); - safelenClauseOperand = firOpBuilder.getI64IntegerAttr(*safelenVal); - } - } + ClauseProcessor cp(converter, loopOpClauseList); + cp.processCollapse(currentLocation, eval, lowerBound, upperBound, step, iv, + loopVarTypeSize); + cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); + cp.processIf(stmtCtx, ifClauseOperand); + 
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); + cp.processSimdlen(simdlenClauseOperand); + cp.processSafelen(safelenClauseOperand); // The types of lower bound, upper bound, and step are converted into the // type of the loop variable if necessary. @@ -1719,16 +2309,15 @@ // 2.9.3.1 SIMD construct // TODO: Support all the clauses if (llvm::omp::OMPD_simd == ompDirective) { - TypeRange resultType; + mlir::TypeRange resultType; auto simdLoopOp = firOpBuilder.create( currentLocation, resultType, lowerBound, upperBound, step, alignedVars, - nullptr, ifClauseOperand, nontemporalVars, - orderClauseOperand.dyn_cast_or_null(), - simdlenClauseOperand, safelenClauseOperand, + /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars, + orderClauseOperand, simdlenClauseOperand, safelenClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr()); - createBodyOfOp(simdLoopOp, converter, currentLocation, - eval, &loopOpClauseList, iv, - /*outer=*/false, &dsp); + createBodyOfOp( + simdLoopOp, converter, currentLocation, eval, &loopOpClauseList, iv, + /*outer=*/false, &dsp); return; } @@ -1742,66 +2331,21 @@ ? nullptr : mlir::ArrayAttr::get(firOpBuilder.getContext(), reductionDeclSymbols), - scheduleClauseOperand.dyn_cast_or_null(), - scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr, - /*simd_modifier=*/nullptr, - noWaitClauseOperand.dyn_cast_or_null(), - orderedClauseOperand.dyn_cast_or_null(), - orderClauseOperand.dyn_cast_or_null(), + scheduleValClauseOperand, scheduleChunkClauseOperand, + /*schedule_modifiers=*/nullptr, + /*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand, + orderClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr()); // Handle attribute based clauses. 
- for (const Fortran::parser::OmpClause &clause : loopOpClauseList.v) { - if (const auto &orderedClause = - std::get_if(&clause.u)) { - if (orderedClause->v.has_value()) { - const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); - const std::optional orderedClauseValue = - Fortran::evaluate::ToInt64(*expr); - wsLoopOp.setOrderedValAttr( - firOpBuilder.getI64IntegerAttr(*orderedClauseValue)); - } else { - wsLoopOp.setOrderedValAttr(firOpBuilder.getI64IntegerAttr(0)); - } - } else if (const auto &scheduleClause = - std::get_if( - &clause.u)) { - mlir::MLIRContext *context = firOpBuilder.getContext(); - const auto &scheduleType = scheduleClause->v; - const auto &scheduleKind = - std::get( - scheduleType.t); - switch (scheduleKind) { - case Fortran::parser::OmpScheduleClause::ScheduleType::Static: - wsLoopOp.setScheduleValAttr(omp::ClauseScheduleKindAttr::get( - context, omp::ClauseScheduleKind::Static)); - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: - wsLoopOp.setScheduleValAttr(omp::ClauseScheduleKindAttr::get( - context, omp::ClauseScheduleKind::Dynamic)); - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: - wsLoopOp.setScheduleValAttr(omp::ClauseScheduleKindAttr::get( - context, omp::ClauseScheduleKind::Guided)); - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: - wsLoopOp.setScheduleValAttr(omp::ClauseScheduleKindAttr::get( - context, omp::ClauseScheduleKind::Auto)); - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: - wsLoopOp.setScheduleValAttr(omp::ClauseScheduleKindAttr::get( - context, omp::ClauseScheduleKind::Runtime)); - break; - } - mlir::omp::ScheduleModifier scheduleModifier = - getScheduleModifier(scheduleClause->v); - if (scheduleModifier != mlir::omp::ScheduleModifier::none) - wsLoopOp.setScheduleModifierAttr( - omp::ScheduleModifierAttr::get(context, scheduleModifier)); - if (getSIMDModifier(scheduleClause->v) != - 
mlir::omp::ScheduleModifier::none) - wsLoopOp.setSimdModifierAttr(firOpBuilder.getUnitAttr()); - } + if (cp.processOrdered(orderedClauseOperand)) + wsLoopOp.setOrderedValAttr(orderedClauseOperand); + + if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand, + scheduleSimdClauseOperand)) { + wsLoopOp.setScheduleValAttr(scheduleValClauseOperand); + wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand); + wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand); } // In FORTRAN `nowait` clause occur at the end of `omp do` directive. // i.e @@ -1813,13 +2357,14 @@ loopConstruct.t)) { const auto &clauseList = std::get((*endClauseList).t); - for (const Fortran::parser::OmpClause &clause : clauseList.v) - if (std::get_if(&clause.u)) - wsLoopOp.setNowaitAttr(firOpBuilder.getUnitAttr()); + if (ClauseProcessor(converter, clauseList) + .processNowait(nowaitClauseOperand)) + wsLoopOp.setNowaitAttr(nowaitClauseOperand); } - createBodyOfOp(wsLoopOp, converter, currentLocation, eval, - &loopOpClauseList, iv, /*outer=*/false, &dsp); + createBodyOfOp(wsLoopOp, converter, currentLocation, + eval, &loopOpClauseList, iv, + /*outer=*/false, &dsp); } static void @@ -1840,140 +2385,60 @@ mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand, priorityClauseOperand; mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - SmallVector allocateOperands, allocatorOperands, dependOperands, - reductionVars; - SmallVector dependTypeOperands, reductionDeclSymbols; + llvm::SmallVector allocateOperands, allocatorOperands, + dependOperands, reductionVars; + llvm::SmallVector dependTypeOperands, reductionDeclSymbols; mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr; const auto &opClauseList = std::get(beginBlockDirective.t); - for (const auto &clause : opClauseList.v) { - mlir::Location clauseLocation = converter.genLocation(clause.source); - if (const auto &ifClause = - std::get_if(&clause.u)) { - ifClauseOperand = - getIfClauseOperand(converter, stmtCtx, 
ifClause, clauseLocation); - } else if (const auto &numThreadsClause = - std::get_if( - &clause.u)) { - // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. - numThreadsClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); - } else if (const auto &procBindClause = - std::get_if( - &clause.u)) { - procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); - } else if (const auto &allocateClause = - std::get_if( - &clause.u)) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, - allocateOperands); - } else if (std::get_if(&clause.u) || - std::get_if( - &clause.u) || - std::get_if(&clause.u)) { - // Privatisation and copyin clauses are handled elsewhere. - continue; - } else if (std::get_if(&clause.u)) { - // Shared is the default behavior in the IR, so no handling is required. - continue; - } else if (const auto &defaultClause = - std::get_if( - &clause.u)) { - if ((defaultClause->v.v == - Fortran::parser::OmpDefaultClause::Type::Shared) || - (defaultClause->v.v == - Fortran::parser::OmpDefaultClause::Type::None)) { - // Default clause with shared or none do not require any handling since - // Shared is the default behavior in the IR and None is only required - // for semantic checks. - continue; - } - } else if (std::get_if(&clause.u)) { - // Nothing needs to be done for threads clause. - continue; - } else if (std::get_if(&clause.u)) { - // Map clause is exclusive to Target Data directives. It is handled - // as part of the TargetOp creation. - continue; - } else if (std::get_if( - &clause.u)) { - // UseDevicePtr clause is exclusive to Target Data directives. It is - // handled as part of the TargetOp creation. - continue; - } else if (std::get_if( - &clause.u)) { - // UseDeviceAddr clause is exclusive to Target Data directives. It is - // handled as part of the TargetOp creation. 
- continue; - } else if (std::get_if( - &clause.u)) { - // Handled as part of TargetOp creation. - continue; - } else if (const auto &finalClause = - std::get_if(&clause.u)) { - mlir::Value finalVal = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); - finalClauseOperand = firOpBuilder.createConvert( - currentLocation, firOpBuilder.getI1Type(), finalVal); - } else if (std::get_if(&clause.u)) { - untiedAttr = firOpBuilder.getUnitAttr(); - } else if (std::get_if(&clause.u)) { - mergeableAttr = firOpBuilder.getUnitAttr(); - } else if (const auto &priorityClause = - std::get_if( - &clause.u)) { - priorityClauseOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); - } else if (const auto &reductionClause = - std::get_if( - &clause.u)) { - addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); - } else if (const auto &dependClause = - std::get_if(&clause.u)) { - const std::list &depVal = - std::get>( - std::get( - dependClause->v.u) - .t); - omp::ClauseTaskDependAttr dependTypeOperand = - genDependKindAttr(firOpBuilder, dependClause); - dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), - dependTypeOperand); - for (const Fortran::parser::Designator &ompObject : depVal) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::DataRef &designator) { - if (const Fortran::parser::Name *name = - std::get_if(&designator.u)) { - sym = name->symbol; - } else if (std::get_if>( - &designator.u)) { - TODO(converter.getCurrentLocation(), - "array sections not supported for task depend"); - } - }, - [&](const Fortran::parser::Substring &designator) { - TODO(converter.getCurrentLocation(), - "substring not supported for task depend"); - }}, - (ompObject).u); - const mlir::Value variable = converter.getSymbolAddress(*sym); - 
dependOperands.push_back(((variable))); - } - } else { + ClauseProcessor cp(converter, opClauseList); + cp.processIf(stmtCtx, ifClauseOperand); + cp.processNumThreads(stmtCtx, numThreadsClauseOperand); + cp.processProcBind(procBindKindAttr); + cp.processAllocate(allocatorOperands, allocateOperands); + cp.processDefault(); + cp.processFinal(stmtCtx, finalClauseOperand); + cp.processUntied(untiedAttr); + cp.processMergeable(mergeableAttr); + cp.processPriority(stmtCtx, priorityClauseOperand); + cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); + cp.processDepend(dependTypeOperands, dependOperands); + + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (!std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + // Privatisation and copyin clauses are handled elsewhere. + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + // Shared is the default behavior in the IR, so no handling is required. + !std::get_if(&clause.u) && + // Nothing needs to be done for threads clause. + !std::get_if(&clause.u) && + // Map, UseDevicePtr, UseDeviceAddr and ThreadLimit clauses are + // exclusive to Target directives. They are handled as part of the + // TargetOp creation. 
+ !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u)) { TODO(converter.getCurrentLocation(), "OpenMP Block construct clause"); } } - for (const auto &clause : - std::get(endBlockDirective.t).v) { - if (std::get_if(&clause.u)) - nowaitAttr = firOpBuilder.getUnitAttr(); - } + ClauseProcessor(converter, + std::get(endBlockDirective.t)) + .processNowait(nowaitAttr); if (blockDirective.v == llvm::omp::OMPD_parallel) { // Create and insert the operation. @@ -1985,26 +2450,27 @@ : mlir::ArrayAttr::get(firOpBuilder.getContext(), reductionDeclSymbols), procBindKindAttr); - createBodyOfOp(parallelOp, converter, currentLocation, - eval, &opClauseList); + createBodyOfOp(parallelOp, converter, + currentLocation, eval, &opClauseList); } else if (blockDirective.v == llvm::omp::OMPD_master) { auto masterOp = firOpBuilder.create(currentLocation, argTy); - createBodyOfOp(masterOp, converter, currentLocation, eval); + createBodyOfOp(masterOp, converter, currentLocation, + eval); } else if (blockDirective.v == llvm::omp::OMPD_single) { auto singleOp = firOpBuilder.create( currentLocation, allocateOperands, allocatorOperands, nowaitAttr); - createBodyOfOp(singleOp, converter, currentLocation, eval, - &opClauseList); + createBodyOfOp(singleOp, converter, currentLocation, + eval, &opClauseList); } else if (blockDirective.v == llvm::omp::OMPD_ordered) { auto orderedOp = firOpBuilder.create( currentLocation, /*simd=*/false); - createBodyOfOp(orderedOp, converter, currentLocation, - eval); + createBodyOfOp(orderedOp, converter, + currentLocation, eval); } else if (blockDirective.v == llvm::omp::OMPD_task) { auto taskOp = firOpBuilder.create( currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr, - mergeableAttr, /*in_reduction_vars=*/ValueRange(), + mergeableAttr, /*in_reduction_vars=*/mlir::ValueRange(), /*in_reductions=*/nullptr, priorityClauseOperand, dependTypeOperands.empty() ? 
nullptr @@ -2015,7 +2481,7 @@ } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) { // TODO: Add task_reduction support auto taskGroupOp = firOpBuilder.create( - currentLocation, /*task_reduction_vars=*/ValueRange(), + currentLocation, /*task_reduction_vars=*/mlir::ValueRange(), /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); createBodyOfOp(taskGroupOp, converter, currentLocation, eval, &opClauseList); @@ -2036,6 +2502,7 @@ const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location currentLocation = converter.getCurrentLocation(); + mlir::IntegerAttr hintClauseOp; std::string name; const Fortran::parser::OmpCriticalDirective &cd = std::get(criticalConstruct.t); @@ -2044,42 +2511,35 @@ std::get>(cd.t).value().ToString(); } - uint64_t hint = 0; const auto &clauseList = std::get(cd.t); - for (const Fortran::parser::OmpClause &clause : clauseList.v) - if (auto hintClause = - std::get_if(&clause.u)) { - const auto *expr = Fortran::semantics::GetExpr(hintClause->v); - hint = *Fortran::evaluate::ToInt64(*expr); - break; - } + ClauseProcessor(converter, clauseList).processHint(hintClauseOp); mlir::omp::CriticalOp criticalOp = [&]() { if (name.empty()) { - return firOpBuilder.create(currentLocation, - FlatSymbolRefAttr()); - } else { - mlir::ModuleOp module = firOpBuilder.getModule(); - mlir::OpBuilder modBuilder(module.getBodyRegion()); - auto global = module.lookupSymbol(name); - if (!global) - global = modBuilder.create( - currentLocation, name, hint); return firOpBuilder.create( - currentLocation, mlir::FlatSymbolRefAttr::get( - firOpBuilder.getContext(), global.getSymName())); + currentLocation, mlir::FlatSymbolRefAttr()); } + mlir::ModuleOp module = firOpBuilder.getModule(); + mlir::OpBuilder modBuilder(module.getBodyRegion()); + auto global = module.lookupSymbol(name); + if (!global) + global = modBuilder.create( + currentLocation, + 
mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp); + return firOpBuilder.create( + currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), + global.getSymName())); }(); - createBodyOfOp(criticalOp, converter, currentLocation, eval); + createBodyOfOp(criticalOp, converter, currentLocation, + eval); } static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { - - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); const Fortran::parser::OpenMPConstruct *parentOmpConstruct = eval.parentConstruct->getIf(); assert(parentOmpConstruct && @@ -2098,51 +2558,37 @@ // all privatization is done within `omp.section` operations. mlir::omp::SectionOp sectionOp = firOpBuilder.create(currentLocation); - createBodyOfOp(sectionOp, converter, currentLocation, eval, - §ionsClauseList); + createBodyOfOp(sectionOp, converter, currentLocation, + eval, §ionsClauseList); } static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); - SmallVector reductionVars, allocateOperands, allocatorOperands; - mlir::UnitAttr noWaitClauseOperand; - const auto §ionsClauseList = std::get( - std::get(sectionsConstruct.t) - .t); - for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) { - - // Reduction Clause - if (std::get_if(&clause.u)) { - TODO(currentLocation, "OMPC_Reduction"); - - // Allocate clause - } else if (const auto &allocateClause = - std::get_if( - &clause.u)) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, - 
allocateOperands); - } - } - const auto &endSectionsClauseList = + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + llvm::SmallVector reductionVars, allocateOperands, + allocatorOperands; + mlir::UnitAttr nowaitClauseOperand; + const auto &beginSectionsDirective = + std::get(sectionsConstruct.t); + const auto §ionsClauseList = + std::get(beginSectionsDirective.t); + + ClauseProcessor cp(converter, sectionsClauseList); + cp.processSectionsReduction(currentLocation); + cp.processAllocate(allocatorOperands, allocateOperands); + + const auto &endSectionsDirective = std::get(sectionsConstruct.t); - const auto &clauseList = - std::get(endSectionsClauseList.t); - for (const auto &clause : clauseList.v) { - // Nowait clause - if (std::get_if(&clause.u)) { - noWaitClauseOperand = firOpBuilder.getUnitAttr(); - } - } + const auto &endSectionsClauseList = + std::get(endSectionsDirective.t); + ClauseProcessor(converter, endSectionsClauseList) + .processNowait(nowaitClauseOperand); llvm::omp::Directive dir = - std::get( - std::get( - sectionsConstruct.t) - .t) + std::get(beginSectionsDirective.t) .v; // Parallel Sections Construct @@ -2152,7 +2598,7 @@ std::get( sectionsConstruct.t)); auto sectionsOp = firOpBuilder.create( - currentLocation, /*reduction_vars*/ ValueRange(), + currentLocation, /*reduction_vars*/ mlir::ValueRange(), /*reductions=*/nullptr, allocateOperands, allocatorOperands, /*nowait=*/nullptr); createBodyOfOp(sectionsOp, converter, currentLocation, eval); @@ -2160,10 +2606,10 @@ // Sections Construct } else if (dir == llvm::omp::Directive::OMPD_sections) { auto sectionsOp = firOpBuilder.create( - currentLocation, reductionVars, /*reductions = */ nullptr, - allocateOperands, allocatorOperands, noWaitClauseOperand); - createBodyOfOp(sectionsOp, converter, currentLocation, - eval); + currentLocation, reductionVars, /*reductions=*/nullptr, + allocateOperands, allocatorOperands, 
nowaitClauseOperand); + createBodyOfOp(sectionsOp, converter, + currentLocation, eval); } } @@ -2204,34 +2650,39 @@ const Fortran::parser::OmpAtomicClauseList &clauseList, mlir::IntegerAttr &hint, mlir::omp::ClauseMemoryOrderKindAttr &memoryOrder) { - auto &firOpBuilder = converter.getFirOpBuilder(); - for (const auto &clause : clauseList.v) { - if (auto ompClause = std::get_if(&clause.u)) { - if (auto hintClause = + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + for (const Fortran::parser::OmpAtomicClause &clause : clauseList.v) { + if (const auto *ompClause = + std::get_if(&clause.u)) { + if (const auto *hintClause = std::get_if(&ompClause->u)) { const auto *expr = Fortran::semantics::GetExpr(hintClause->v); uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr); hint = firOpBuilder.getI64IntegerAttr(hintExprValue); } - } else if (auto ompMemoryOrderClause = + } else if (const auto *ompMemoryOrderClause = std::get_if( &clause.u)) { if (std::get_if( &ompMemoryOrderClause->v.u)) { memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get( - firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Acquire); + firOpBuilder.getContext(), + mlir::omp::ClauseMemoryOrderKind::Acquire); } else if (std::get_if( &ompMemoryOrderClause->v.u)) { memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get( - firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Relaxed); + firOpBuilder.getContext(), + mlir::omp::ClauseMemoryOrderKind::Relaxed); } else if (std::get_if( &ompMemoryOrderClause->v.u)) { memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get( - firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Seq_cst); + firOpBuilder.getContext(), + mlir::omp::ClauseMemoryOrderKind::Seq_cst); } else if (std::get_if( &ompMemoryOrderClause->v.u)) { memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get( - firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Release); + firOpBuilder.getContext(), + mlir::omp::ClauseMemoryOrderKind::Release); } } } @@ -2239,65 
+2690,65 @@ static void genOmpAtomicCaptureStatement( Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, mlir::Value from_address, - mlir::Value to_address, + Fortran::lower::pft::Evaluation &eval, mlir::Value fromAddress, + mlir::Value toAddress, const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, const Fortran::parser::OmpAtomicClauseList *rightHandClauseList, mlir::Type elementType) { // Generate `omp.atomic.read` operation for atomic assigment statements - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); // If no hint clause is specified, the effect is as if // hint(omp_sync_hint_none) had been specified. mlir::IntegerAttr hint = nullptr; - mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; if (leftHandClauseList) genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, - memory_order); + memoryOrder); if (rightHandClauseList) genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, - memory_order); + memoryOrder); firOpBuilder.create( - currentLocation, from_address, to_address, - mlir::TypeAttr::get(elementType), hint, memory_order); + currentLocation, fromAddress, toAddress, mlir::TypeAttr::get(elementType), + hint, memoryOrder); } static void genOmpAtomicWriteStatement( Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, mlir::Value lhs_addr, - mlir::Value rhs_expr, + Fortran::lower::pft::Evaluation &eval, mlir::Value lhsAddr, + mlir::Value rhsExpr, const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, const Fortran::parser::OmpAtomicClauseList *rightHandClauseList, mlir::Value *evaluatedExprValue = nullptr) { // Generate `omp.atomic.write` operation for atomic 
assignment statements - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); // If no hint clause is specified, the effect is as if // hint(omp_sync_hint_none) had been specified. mlir::IntegerAttr hint = nullptr; - mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; if (leftHandClauseList) genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, - memory_order); + memoryOrder); if (rightHandClauseList) genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, - memory_order); - firOpBuilder.create(currentLocation, lhs_addr, - rhs_expr, hint, memory_order); + memoryOrder); + firOpBuilder.create(currentLocation, lhsAddr, + rhsExpr, hint, memoryOrder); } static void genOmpAtomicUpdateStatement( Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, mlir::Value lhs_addr, + Fortran::lower::pft::Evaluation &eval, mlir::Value lhsAddr, mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable, const Fortran::parser::Expr &assignmentStmtExpr, const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) { // Generate `omp.atomic.update` operation for atomic assignment statements - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); // If no hint clause is specified, the effect is as if // hint(omp_sync_hint_none) had been specified. 
@@ -2310,22 +2761,22 @@ genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, memoryOrder); auto atomicUpdateOp = firOpBuilder.create( - currentLocation, lhs_addr, hint, memoryOrder); + currentLocation, lhsAddr, hint, memoryOrder); //// Generate body of Atomic Update operation // If an argument for the region is provided then create the block with that // argument. Also update the symbol's address with the argument mlir value. - SmallVector varTys = {varType}; - SmallVector locs = {currentLocation}; + llvm::SmallVector varTys = {varType}; + llvm::SmallVector locs = {currentLocation}; firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs); mlir::Value val = fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0)); - auto varDesignator = + const auto *varDesignator = std::get_if>( &assignmentStmtVariable.u); assert(varDesignator && "Variable designator for atomic update assignment " "statement does not exist"); - const auto *name = + const Fortran::parser::Name *name = Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value()); if (!name) TODO(converter.getCurrentLocation(), @@ -2339,10 +2790,10 @@ firOpBuilder.setInsertionPointToEnd(&block); Fortran::lower::StatementContext stmtCtx; - mlir::Value rhs_expr = fir::getBase(converter.genExprValue( + mlir::Value rhsExpr = fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); mlir::Value convertResult = - firOpBuilder.createConvert(currentLocation, varType, rhs_expr); + firOpBuilder.createConvert(currentLocation, varType, rhsExpr); // Insert the terminator: YieldOp. firOpBuilder.create(currentLocation, convertResult); // Reset the insert point to before the terminator. @@ -2363,12 +2814,11 @@ const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v; Fortran::lower::StatementContext stmtCtx; // Get the value and address of atomic write operands. 
- mlir::Value rhs_expr = + mlir::Value rhsExpr = fir::getBase(converter.genExprValue(assign.rhs, stmtCtx)); - - mlir::Value lhs_addr = + mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx)); - genOmpAtomicWriteStatement(converter, eval, lhs_addr, rhs_expr, + genOmpAtomicWriteStatement(converter, eval, lhsAddr, rhsExpr, &leftHandClauseList, &rightHandClauseList); } @@ -2412,14 +2862,14 @@ std::get<3>(atomicUpdate.t).statement.t); Fortran::lower::StatementContext stmtCtx; - mlir::Value lhs_addr = fir::getBase(converter.genExprAddr( + mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); mlir::Type varType = fir::getBase( converter.genExprValue( *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)) .getType(); - genOmpAtomicUpdateStatement(converter, eval, lhs_addr, varType, + genOmpAtomicUpdateStatement(converter, eval, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr, &leftHandClauseList, &rightHandClauseList); } @@ -2438,7 +2888,7 @@ atomicConstruct.t) .statement.t); Fortran::lower::StatementContext stmtCtx; - mlir::Value lhs_addr = fir::getBase(converter.genExprAddr( + mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); mlir::Type varType = fir::getBase( @@ -2447,7 +2897,7 @@ .getType(); // If atomic-clause is not present on the construct, the behaviour is as if // the update clause is specified - genOmpAtomicUpdateStatement(converter, eval, lhs_addr, varType, + genOmpAtomicUpdateStatement(converter, eval, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr, &atomicClauseList, nullptr); } @@ -2460,15 +2910,15 @@ mlir::Location currentLocation = converter.getCurrentLocation(); mlir::IntegerAttr hint = nullptr; - mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; const 
Fortran::parser::OmpAtomicClauseList &rightHandClauseList = std::get<2>(atomicCapture.t); const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = std::get<0>(atomicCapture.t); genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, - memory_order); + memoryOrder); genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, - memory_order); + memoryOrder); const Fortran::parser::AssignmentStmt &stmt1 = std::get<3>(atomicCapture.t).v.statement; @@ -2519,7 +2969,7 @@ .getType(); auto atomicCaptureOp = firOpBuilder.create( - currentLocation, hint, memory_order); + currentLocation, hint, memoryOrder); firOpBuilder.createBlock(&atomicCaptureOp.getRegion()); mlir::Block &block = atomicCaptureOp.getRegion().back(); firOpBuilder.setInsertionPointToStart(&block); @@ -2550,55 +3000,143 @@ /*leftHandClauseList=*/nullptr, /*rightHandClauseList=*/nullptr); } - } else { - // Atomic capture construct is of the form [update-stmt, capture-stmt] - firOpBuilder.setInsertionPointToEnd(&block); - const Fortran::semantics::SomeExpr &fromExpr = - *Fortran::semantics::GetExpr(stmt2Expr); - elementType = converter.genType(fromExpr); - genOmpAtomicCaptureStatement(converter, eval, stmt1LHSArg, stmt2LHSArg, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, elementType); - firOpBuilder.setInsertionPointToStart(&block); - genOmpAtomicUpdateStatement(converter, eval, stmt1LHSArg, stmt1VarType, - stmt1Var, stmt1Expr, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr); + } else { + // Atomic capture construct is of the form [update-stmt, capture-stmt] + firOpBuilder.setInsertionPointToEnd(&block); + const Fortran::semantics::SomeExpr &fromExpr = + *Fortran::semantics::GetExpr(stmt2Expr); + elementType = converter.genType(fromExpr); + genOmpAtomicCaptureStatement(converter, eval, stmt1LHSArg, stmt2LHSArg, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, elementType); + 
firOpBuilder.setInsertionPointToStart(&block); + genOmpAtomicUpdateStatement(converter, eval, stmt1LHSArg, stmt1VarType, + stmt1Var, stmt1Expr, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr); + } + firOpBuilder.setInsertionPointToEnd(&block); + firOpBuilder.create(currentLocation); + firOpBuilder.setInsertionPointToStart(&block); +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::OmpAtomicRead &atomicRead) { + genOmpAtomicRead(converter, eval, atomicRead); + }, + [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) { + genOmpAtomicWrite(converter, eval, atomicWrite); + }, + [&](const Fortran::parser::OmpAtomic &atomicConstruct) { + genOmpAtomic(converter, eval, atomicConstruct); + }, + [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { + genOmpAtomicUpdate(converter, eval, atomicUpdate); + }, + [&](const Fortran::parser::OmpAtomicCapture &atomicCapture) { + genOmpAtomicCapture(converter, eval, atomicCapture); + }, + }, + atomicConstruct.u); +} + +static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + llvm::SmallVector symbolAndClause; + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + + // The default capture type + mlir::omp::DeclareTargetDeviceType deviceType = + mlir::omp::DeclareTargetDeviceType::any; + const auto &spec = std::get( + declareTargetConstruct.t); + if (const auto *objectList{ + Fortran::parser::Unwrap(spec.u)}) { + // Case: declare target(func, var1, var2) + gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + symbolAndClause); + } else if (const auto *clauseList{ + Fortran::parser::Unwrap( + spec.u)}) { + if (clauseList->v.empty()) { + // Case: 
declare target, implicit capture of function + symbolAndClause.emplace_back( + mlir::omp::DeclareTargetCaptureClause::to, + eval.getOwningProcedure()->getSubprogramSymbol()); + } + + ClauseProcessor cp(converter, *clauseList); + cp.processTo(symbolAndClause); + cp.processLink(symbolAndClause); + cp.processDeviceType(deviceType); + } + + for (const DeclareTargetCapturePair &symClause : symbolAndClause) { + mlir::Operation *op = mod.lookupSymbol( + converter.mangleName(std::get(symClause))); + // There's several cases this can currently be triggered and it could be + // one of the following: + // 1) Invalid argument passed to a declare target that currently isn't + // captured by a frontend semantic check + // 2) The symbol of a valid argument is not correctly updated by one of + // the prior passes, resulting in missing symbol information + // 3) It's a variable internal to a module or program, that is legal by + // Fortran OpenMP standards, but is currently unhandled as they do not + // appear in the symbol table as they are represented as allocas + if (!op) + TODO(converter.getCurrentLocation(), + "Missing symbol, possible case of currently unsupported use of " + "a program local variable in declare target or erroneous symbol " + "information "); + + auto declareTargetOp = + llvm::dyn_cast(op); + if (!declareTargetOp) + fir::emitFatalError( + converter.getCurrentLocation(), + "Attempt to apply declare target on unsupported operation"); + + // The function or global already has a declare target applied to it, very + // likely through implicit capture (usage in another declare target + // function/subroutine). 
It should be marked as any if it has been assigned + // both host and nohost, else we skip, as there is no change + if (declareTargetOp.isDeclareTarget()) { + if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) + declareTargetOp.setDeclareTarget( + mlir::omp::DeclareTargetDeviceType::any, + std::get(symClause)); + continue; + } + + declareTargetOp.setDeclareTarget( + deviceType, std::get(symClause)); } - firOpBuilder.setInsertionPointToEnd(&block); - firOpBuilder.create(currentLocation); - firOpBuilder.setInsertionPointToStart(&block); } -static void -genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { - std::visit(Fortran::common::visitors{ - [&](const Fortran::parser::OmpAtomicRead &atomicRead) { - genOmpAtomicRead(converter, eval, atomicRead); - }, - [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) { - genOmpAtomicWrite(converter, eval, atomicWrite); - }, - [&](const Fortran::parser::OmpAtomic &atomicConstruct) { - genOmpAtomic(converter, eval, atomicConstruct); - }, - [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { - genOmpAtomicUpdate(converter, eval, atomicUpdate); - }, - [&](const Fortran::parser::OmpAtomicCapture &atomicCapture) { - genOmpAtomicCapture(converter, eval, atomicCapture); - }, - }, - atomicConstruct.u); +//===----------------------------------------------------------------------===// +// Public functions +//===----------------------------------------------------------------------===// + +void Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder, + mlir::Operation *op, + mlir::Location loc) { + if (mlir::isa(op)) + builder.create(loc); + else + builder.create(loc); } void Fortran::lower::genOpenMPConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPConstruct &ompConstruct) { - std::visit( common::visitors{ [&](const 
Fortran::parser::OpenMPStandaloneConstruct @@ -2641,37 +3179,51 @@ ompConstruct.u); } -fir::GlobalOp globalInitialization(Fortran::lower::AbstractConverter &converter, - fir::FirOpBuilder &firOpBuilder, - const Fortran::semantics::Symbol &sym, - const Fortran::lower::pft::Variable &var, - mlir::Location currentLocation) { - mlir::Type ty = converter.genType(sym); - std::string globalName = converter.mangleName(sym); - mlir::StringAttr linkage = firOpBuilder.createInternalLinkage(); - fir::GlobalOp global = - firOpBuilder.createGlobal(currentLocation, ty, globalName, linkage); +void Fortran::lower::genOpenMPDeclarativeConstruct( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) { + std::visit( + common::visitors{ + [&](const Fortran::parser::OpenMPDeclarativeAllocate + &declarativeAllocate) { + TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); + }, + [&](const Fortran::parser::OpenMPDeclareReductionConstruct + &declareReductionConstruct) { + TODO(converter.getCurrentLocation(), + "OpenMPDeclareReductionConstruct"); + }, + [&](const Fortran::parser::OpenMPDeclareSimdConstruct + &declareSimdConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); + }, + [&](const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + genOMP(converter, eval, declareTargetConstruct); + }, + [&](const Fortran::parser::OpenMPRequiresConstruct + &requiresConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); + }, + [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { + // The directive is lowered when instantiating the variable to + // support the case of threadprivate variable declared in module. + }, + }, + ompDeclConstruct.u); +} - // Create default initialization for non-character scalar. 
- if (Fortran::semantics::IsAllocatableOrPointer(sym)) { - mlir::Type baseAddrType = ty.dyn_cast().getEleTy(); - Fortran::lower::createGlobalInitialization( - firOpBuilder, global, [&](fir::FirOpBuilder &b) { - mlir::Value nullAddr = - b.createNullConstant(currentLocation, baseAddrType); - mlir::Value box = - b.create(currentLocation, ty, nullAddr); - b.create(currentLocation, box); - }); - } else { - Fortran::lower::createGlobalInitialization( - firOpBuilder, global, [&](fir::FirOpBuilder &b) { - mlir::Value undef = b.create(currentLocation, ty); - b.create(currentLocation, undef); - }); +int64_t Fortran::lower::getCollapseValue( + const Fortran::parser::OmpClauseList &clauseList) { + for (const Fortran::parser::OmpClause &clause : clauseList.v) { + if (const auto &collapseClause = + std::get_if(&clause.u)) { + const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); + return Fortran::evaluate::ToInt64(*expr).value(); + } } - - return global; + return 1; } void Fortran::lower::genThreadprivateOp( @@ -2745,181 +3297,6 @@ } } -void handleDeclareTarget(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - llvm::SmallVector, - 0> - symbolAndClause; - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - - auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.push_back( - std::make_pair(clause, *name->symbol)); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.push_back(std::make_pair(clause, *name.symbol)); - }}, - ompObject.u); - } - }; - - // 
The default capture type - Fortran::parser::OmpDeviceTypeClause::Type deviceType = - Fortran::parser::OmpDeviceTypeClause::Type::Any; - const auto &spec = std::get( - declareTargetConstruct.t); - if (const auto *objectList{ - Fortran::parser::Unwrap(spec.u)}) { - // Case: declare target(func, var1, var2) - findFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to); - } else if (const auto *clauseList{ - Fortran::parser::Unwrap( - spec.u)}) { - if (clauseList->v.empty()) { - // Case: declare target, implicit capture of function - symbolAndClause.push_back( - std::make_pair(mlir::omp::DeclareTargetCaptureClause::to, - eval.getOwningProcedure()->getSubprogramSymbol())); - } - - for (const Fortran::parser::OmpClause &clause : clauseList->v) { - if (const auto *toClause = - std::get_if(&clause.u)) { - // Case: declare target to(func, var1, var2)... - findFuncAndVarSyms(toClause->v, - mlir::omp::DeclareTargetCaptureClause::to); - } else if (const auto *linkClause = - std::get_if(&clause.u)) { - // Case: declare target link(var1, var2)... - findFuncAndVarSyms(linkClause->v, - mlir::omp::DeclareTargetCaptureClause::link); - } else if (const auto *deviceClause = - std::get_if( - &clause.u)) { - // Case: declare target ... 
device_type(any | host | nohost) - deviceType = deviceClause->v.v; - } - } - } - - for (std::pair - symClause : symbolAndClause) { - mlir::Operation *op = - mod.lookupSymbol(converter.mangleName(std::get<1>(symClause))); - // There's several cases this can currently be triggered and it could be - // one of the following: - // 1) Invalid argument passed to a declare target that currently isn't - // captured by a frontend semantic check - // 2) The symbol of a valid argument is not correctly updated by one of - // the prior passes, resulting in missing symbol information - // 3) It's a variable internal to a module or program, that is legal by - // Fortran OpenMP standards, but is currently unhandled as they do not - // appear in the symbol table as they are represented as allocas - if (!op) - TODO(converter.getCurrentLocation(), - "Missing symbol, possible case of currently unsupported use of " - "a program local variable in declare target or erroneous symbol " - "information "); - - auto declareTargetOp = dyn_cast(op); - if (!declareTargetOp) - fir::emitFatalError( - converter.getCurrentLocation(), - "Attempt to apply declare target on unsupported operation"); - - mlir::omp::DeclareTargetDeviceType newDeviceType; - switch (deviceType) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: - newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: - newDeviceType = mlir::omp::DeclareTargetDeviceType::host; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Any: - newDeviceType = mlir::omp::DeclareTargetDeviceType::any; - break; - } - - // The function or global already has a declare target applied to it, - // very likely through implicit capture (usage in another declare - // target function/subroutine). 
It should be marked as any if it has - // been assigned both host and nohost, else we skip, as there is no - // change - if (declareTargetOp.isDeclareTarget()) { - if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType) - declareTargetOp.setDeclareTarget( - mlir::omp::DeclareTargetDeviceType::any, std::get<0>(symClause)); - continue; - } - - declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(symClause)); - } -} - -void Fortran::lower::genOpenMPDeclarativeConstruct( - Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) { - - std::visit( - common::visitors{ - [&](const Fortran::parser::OpenMPDeclarativeAllocate - &declarativeAllocate) { - TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); - }, - [&](const Fortran::parser::OpenMPDeclareReductionConstruct - &declareReductionConstruct) { - TODO(converter.getCurrentLocation(), - "OpenMPDeclareReductionConstruct"); - }, - [&](const Fortran::parser::OpenMPDeclareSimdConstruct - &declareSimdConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); - }, - [&](const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - handleDeclareTarget(converter, eval, declareTargetConstruct); - }, - [&](const Fortran::parser::OpenMPRequiresConstruct - &requiresConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct"); - }, - [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { - // The directive is lowered when instantiating the variable to - // support the case of threadprivate variable declared in module. 
- }, - }, - ompDeclConstruct.u); -} - -static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp, - mlir::Value loadVal) { - for (auto reductionOperand : reductionOp->getOperands()) { - if (auto compareOp = reductionOperand.getDefiningOp()) { - if (compareOp->getOperand(0) == loadVal || - compareOp->getOperand(1) == loadVal) - assert((mlir::isa(compareOp) || - mlir::isa(compareOp)) && - "Expected comparison not found in reduction intrinsic"); - return compareOp; - } - } - return nullptr; -} - // Generate an OpenMP reduction operation. // TODO: Currently assumes it is either an integer addition/multiplication // reduction, or a logical and reduction. Generalize this for various reduction @@ -2932,14 +3309,14 @@ const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - for (const auto &clause : clauseList.v) { + for (const Fortran::parser::OmpClause &clause : clauseList.v) { if (const auto &reductionClause = std::get_if(&clause.u)) { const auto &redOperator{std::get( reductionClause->v.t)}; const auto &objectList{ std::get(reductionClause->v.t)}; - if (auto reductionOp = + if (const auto *reductionOp = std::get_if(&redOperator.u)) { const auto &intrinsicOp{ std::get( @@ -2956,10 +3333,10 @@ default: continue; } - for (const auto &ompObject : objectList.v) { + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { - if (const auto *symbol{name->symbol}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { mlir::Value reductionVal = converter.getSymbolAddress(*symbol); mlir::Type reductionType = reductionVal.getType().cast().getEleTy(); @@ -2978,7 +3355,7 @@ updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal, &convertOp); removeStoreOp(reductionOp, reductionVal); - } else if (auto reductionOp = + } else if (mlir::Operation *reductionOp = findReductionChain(loadVal, &reductionVal)) { 
updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal); @@ -2988,7 +3365,7 @@ } } } - } else if (auto reductionIntrinsic = + } else if (const auto *reductionIntrinsic = std::get_if( &redOperator.u)) { if (const auto *name{Fortran::parser::Unwrap( @@ -2999,12 +3376,12 @@ (name->source != "iand")) { continue; } - for (const auto &ompObject : objectList.v) { + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { if (const auto *name{Fortran::parser::Unwrap( ompObject)}) { - if (const auto *symbol{name->symbol}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - for (mlir::OpOperand &reductionValUse : + for (const mlir::OpOperand &reductionValUse : reductionVal.getUses()) { if (auto loadOp = mlir::dyn_cast( reductionValUse.getOwner())) { @@ -3044,10 +3421,10 @@ mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal) { for (mlir::OpOperand &loadOperand : loadVal.getUses()) { - if (auto reductionOp = loadOperand.getOwner()) { + if (mlir::Operation *reductionOp = loadOperand.getOwner()) { if (auto convertOp = mlir::dyn_cast(reductionOp)) { for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) { - if (auto reductionOp = convertOperand.getOwner()) + if (mlir::Operation *reductionOp = convertOperand.getOwner()) return reductionOp; } } @@ -3065,6 +3442,23 @@ return nullptr; } +// for a logical operator 'op' reduction X = X op Y +// This function returns the operation responsible for converting Y from +// fir.logical<4> to i1 +fir::ConvertOp +Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp, + mlir::Value loadVal) { + for (mlir::Value reductionOperand : reductionOp->getOperands()) { + if (auto convertOp = + mlir::dyn_cast(reductionOperand.getDefiningOp())) { + if (convertOp.getOperand() == loadVal) + continue; + return convertOp; + } + } + return nullptr; +} + void 
Fortran::lower::updateReduction(mlir::Operation *op, fir::FirOpBuilder &firOpBuilder, mlir::Value loadVal, @@ -3086,29 +3480,13 @@ firOpBuilder.restoreInsertionPoint(insertPtDel); } -// for a logical operator 'op' reduction X = X op Y -// This function returns the operation responsible for converting Y from -// fir.logical<4> to i1 -fir::ConvertOp -Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp, - mlir::Value loadVal) { - for (auto reductionOperand : reductionOp->getOperands()) { - if (auto convertOp = - mlir::dyn_cast(reductionOperand.getDefiningOp())) { - if (convertOp.getOperand() == loadVal) - continue; - return convertOp; - } - } - return nullptr; -} - void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) { - for (auto reductionOpUse : reductionOp->getUsers()) { + for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) { if (auto convertReduction = mlir::dyn_cast(reductionOpUse)) { - for (auto convertReductionUse : convertReduction.getRes().getUsers()) { + for (mlir::Operation *convertReductionUse : + convertReduction.getRes().getUsers()) { if (auto storeOp = mlir::dyn_cast(convertReductionUse)) { if (storeOp.getMemref() == symVal) storeOp.erase();