diff --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.h b/flang/include/flang/Optimizer/Dialect/FIRDialect.h --- a/flang/include/flang/Optimizer/Dialect/FIRDialect.h +++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.h @@ -37,6 +37,7 @@ // clang-format off registry.insert(&designator.u)}; + return dataRef ? std::get_if(&dataRef->u) : nullptr; +} + /* genObjectList: for every object in objectList, appends the value returned by converter.getSymbolAddress to operands and increments objectsCount once per appended operand. A Designator object contributes only when getDesignatorNameIfDataRef returns a name; any other designator is skipped and not counted. NOTE(review): name->symbol and name.symbol are dereferenced without a null check, so this assumes every name has already been resolved to a symbol. */ +static void genObjectList(const Fortran::parser::AccObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + std::int32_t &objectsCount, + SmallVector &operands) { + for (const auto &accObject : objectList.v) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + if (const auto *name = getDesignatorNameIfDataRef(designator)) { + ++objectsCount; + const auto variable = converter.getSymbolAddress(*name->symbol); + operands.push_back(variable); + } + }, + [&](const Fortran::parser::Name &name) { + ++objectsCount; + const auto variable = converter.getSymbolAddress(*name.symbol); + operands.push_back(variable); + }}, + accObject.u); + } +} + /* genACC: lowers an OpenACC loop construct. Only the llvm::acc::ACCD_loop directive is handled; for any other loop directive nothing is generated. Gang, worker, vector, tile and private clause values become operands of the created loop op. Operand segment sizes and the accumulated execution mapping (GANG, WORKER, VECTOR bits) are attached as attributes, and collapse, seq, independent and auto clauses also become attributes. The op gets one block containing a single op built from currentLocation alone (presumably its terminator, per the in-body comment), and the builder insertion point is left at the start of that block. Reduction clauses are not lowered yet. The eval parameter is currently unused. */ +static void genACC(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { + + const auto &beginLoopDirective = + std::get(loopConstruct.t); + const auto &loopDirective = + std::get(beginLoopDirective.t); + + if (loopDirective.v == llvm::acc::ACCD_loop) { + auto &firOpBuilder = converter.getFirOpBuilder(); + auto currentLocation = converter.getCurrentLocation(); + /* argTy is never filled in; it is passed unchanged to the loop op builder (presumably as an empty result-type list). */ llvm::ArrayRef argTy; + + // Add attribute extracted from clauses. 
+ const auto &accClauseList = + std::get(beginLoopDirective.t); + + mlir::Value workerNum; + mlir::Value vectorLength; + mlir::Value gangNum; + mlir::Value gangStatic; + std::int32_t tileOperands = 0; + std::int32_t privateOperands = 0; + std::int32_t reductionOperands = 0; + std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE; + SmallVector operands; + /* NOTE(review): operands is filled in the order the clauses appear in accClauseList. However, the operand segment sizes set after the op is created assume the fixed order gangNum, gangStatic, workerNum, vectorLength, tile, private, reduction. A clause list such as vector(4) gang(8), or a private clause written before a tile clause, therefore assigns operands to the wrong segments. A repeated gang, worker or vector clause would also push more operands than its 0-or-1 segment size (unless semantics rejects duplicates). Consider collecting each group into its own list and concatenating the lists in segment order. */ + // Lower clauses values mapped to operands. + for (const auto &clause : accClauseList.v) { + if (const auto *gangClause = + std::get_if(&clause.u)) { + if (gangClause->v) { + const Fortran::parser::AccGangArgument &x = *gangClause->v; + if (const auto &gangNumValue = + std::get>( + x.t)) { + gangNum = converter.genExprValue( + *Fortran::semantics::GetExpr(gangNumValue.value())); + operands.push_back(gangNum); + } + if (const auto &gangStaticValue = + std::get>(x.t)) { + const auto &expr = + std::get>( + gangStaticValue.value().t); + if (expr) { + gangStatic = + converter.genExprValue(*Fortran::semantics::GetExpr(*expr)); + } else { + // * was passed as value and will be represented as a -1 constant + // integer. 
+ gangStatic = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIntegerType(32), + /* STAR */ -1); + } + operands.push_back(gangStatic); + } + } + executionMapping |= mlir::acc::OpenACCExecMapping::GANG; + } else if (const auto *workerClause = + std::get_if( + &clause.u)) { + if (workerClause->v) { + workerNum = converter.genExprValue( + *Fortran::semantics::GetExpr(*workerClause->v)); + operands.push_back(workerNum); + } + executionMapping |= mlir::acc::OpenACCExecMapping::WORKER; + } else if (const auto *vectorClause = + std::get_if( + &clause.u)) { + if (vectorClause->v) { + vectorLength = converter.genExprValue( + *Fortran::semantics::GetExpr(*vectorClause->v)); + operands.push_back(vectorLength); + } + executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR; + } else if (const auto *tileClause = + std::get_if(&clause.u)) { + const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; + for (const auto &accTileExpr : accTileExprList.v) { + const auto &expr = + std::get>( + accTileExpr.t); + ++tileOperands; + if (expr) { + operands.push_back( + converter.genExprValue(*Fortran::semantics::GetExpr(*expr))); + } else { + // * was passed as value and will be represented as a -1 constant + // integer. + mlir::Value tileStar = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIntegerType(32), + /* STAR */ -1); + operands.push_back(tileStar); + } + } + } else if (const auto *privateClause = + std::get_if( + &clause.u)) { + const Fortran::parser::AccObjectList &accObjectList = privateClause->v; + genObjectList(accObjectList, converter, privateOperands, operands); + } + // Reduction clause is left out for the moment as the clause will probably + // end up having its own operation. 
+ } + + auto loopOp = firOpBuilder.create(currentLocation, argTy, + operands); + + firOpBuilder.createBlock(&loopOp.getRegion()); + auto &block = loopOp.getRegion().back(); + firOpBuilder.setInsertionPointToStart(&block); + // ensure the block is well-formed. + firOpBuilder.create(currentLocation); + /* Segment sizes, in order: gangNum, gangStatic, workerNum, vectorLength, tile, private, reduction. reductionOperands is always 0 because reduction clauses are not lowered yet. */ + loopOp.setAttr(mlir::acc::LoopOp::getOperandSegmentSizeAttr(), + firOpBuilder.getI32VectorAttr( + {gangNum ? 1 : 0, gangStatic ? 1 : 0, workerNum ? 1 : 0, + vectorLength ? 1 : 0, tileOperands, privateOperands, + reductionOperands})); + + loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), + firOpBuilder.getI64IntegerAttr(executionMapping)); + + // Lower clauses mapped to attributes + for (const auto &clause : accClauseList.v) { + if (const auto *collapseClause = + std::get_if(&clause.u)) { + const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); + /* NOTE(review): expr is dereferenced without a null check; confirm GetExpr cannot return null for a collapse argument. When ToInt64 yields no value, no collapse attribute is set. */ const auto collapseValue = Fortran::evaluate::ToInt64(*expr); + if (collapseValue) { + loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(), + firOpBuilder.getI64IntegerAttr(*collapseValue)); + } + } else if (std::get_if(&clause.u)) { + loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(), + firOpBuilder.getUnitAttr()); + } else if (std::get_if( + &clause.u)) { + loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(), + firOpBuilder.getUnitAttr()); + } else if (std::get_if(&clause.u)) { + loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(), + firOpBuilder.getUnitAttr()); + } + } + + // Place the insertion point to the start of the first block. 
+ firOpBuilder.setInsertionPointToStart(&block); + } +} + void Fortran::lower::genOpenACCConstruct( - Fortran::lower::AbstractConverter &absConv, + Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenACCConstruct &accConstruct) { @@ -32,7 +218,7 @@ [&](const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) { TODO(); }, [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { - TODO(); + genACC(converter, eval, loopConstruct); }, [&](const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) { TODO(); },