diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -15,6 +15,7 @@ #include "flang/Lower/Bridge.h" #include "flang/Lower/FIRBuilder.h" #include "flang/Lower/PFTBuilder.h" +#include "flang/Lower/Support/BoxValue.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/OpenACC/OpenACC.h" @@ -48,6 +49,26 @@ } } +template +static void +genObjectListWithModifier(const Clause *x, + Fortran::lower::AbstractConverter &converter, + Fortran::parser::AccDataModifier::Modifier mod, + SmallVectorImpl &operandsWithModifier, + SmallVectorImpl &operands) { + const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; + const Fortran::parser::AccObjectList &accObjectList = + std::get(listWithModifier.t); + const auto &modifier = + std::get>( + listWithModifier.t); + if (modifier && (*modifier).v == mod) { + genObjectList(accObjectList, converter, operandsWithModifier); + } else { + genObjectList(accObjectList, converter, operands); + } +} + static void addOperands(SmallVectorImpl &operands, SmallVectorImpl &operandSegments, const SmallVectorImpl &clauseOperands) { @@ -227,6 +248,193 @@ } } +static void +genACCParallelOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClauseList &accClauseList) { + mlir::Value async; + mlir::Value numGangs; + mlir::Value numWorkers; + mlir::Value vectorLength; + mlir::Value ifCond; + mlir::Value selfCond; + SmallVector waitOperands, reductionOperands, copyOperands, + copyinOperands, copyinReadonlyOperands, copyoutOperands, + copyoutZeroOperands, createOperands, createZeroOperands, noCreateOperands, + presentOperands, devicePtrOperands, attachOperands, privateOperands, + firstprivateOperands; + + // Async, wait and self clauses have optional values but can be present with + // no value as well. When there is no value, the op has an attribute to + // represent the clause. 
+ bool addAsyncAttr = false; + bool addWaitAttr = false; + bool addSelfAttr = false; + + auto &firOpBuilder = converter.getFirOpBuilder(); + auto currentLocation = converter.getCurrentLocation(); + + // Lower clause values mapped to operands. + // Keep track of each group of operands separately as clauses can appear + // more than once. + for (const auto &clause : accClauseList.v) { + if (const auto *asyncClause = + std::get_if(&clause.u)) { + const auto &asyncClauseValue = asyncClause->v; + if (asyncClauseValue) { // async has a value. + async = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*asyncClauseValue))); + } else { + addAsyncAttr = true; + } + } else if (const auto *waitClause = + std::get_if(&clause.u)) { + const auto &waitClauseValue = waitClause->v; + if (waitClauseValue) { // wait has a value. + const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; + const std::list &waitList = + std::get>(waitArg.t); + for (const Fortran::parser::ScalarIntExpr &value : waitList) { + Value v = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(value))); + waitOperands.push_back(v); + } + } else { + addWaitAttr = true; + } + } else if (const auto *numGangsClause = + std::get_if( + &clause.u)) { + numGangs = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numGangsClause->v))); + } else if (const auto *numWorkersClause = + std::get_if( + &clause.u)) { + numWorkers = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numWorkersClause->v))); + } else if (const auto *vectorLengthClause = + std::get_if( + &clause.u)) { + vectorLength = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(vectorLengthClause->v))); + } else if (const auto *ifClause = + std::get_if(&clause.u)) { + Value cond = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + ifCond = firOpBuilder.createConvert(currentLocation, + firOpBuilder.getI1Type(), cond); + } else if 
(const auto *selfClause = + std::get_if(&clause.u)) { + if (selfClause->v) { + Value cond = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*(selfClause->v)))); + selfCond = firOpBuilder.createConvert(currentLocation, + firOpBuilder.getI1Type(), cond); + } else { + addSelfAttr = true; + } + } else if (const auto *copyClause = + std::get_if(&clause.u)) { + genObjectList(copyClause->v, converter, copyOperands); + } else if (const auto *copyinClause = + std::get_if(&clause.u)) { + genObjectListWithModifier( + copyinClause, converter, + Fortran::parser::AccDataModifier::Modifier::ReadOnly, + copyinReadonlyOperands, copyinOperands); + } else if (const auto *copyoutClause = + std::get_if( + &clause.u)) { + genObjectListWithModifier( + copyoutClause, converter, + Fortran::parser::AccDataModifier::Modifier::Zero, copyoutZeroOperands, + copyoutOperands); + } else if (const auto *createClause = + std::get_if(&clause.u)) { + genObjectListWithModifier( + createClause, converter, + Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands, + createOperands); + } else if (const auto *noCreateClause = + std::get_if( + &clause.u)) { + genObjectList(noCreateClause->v, converter, noCreateOperands); + } else if (const auto *presentClause = + std::get_if( + &clause.u)) { + genObjectList(presentClause->v, converter, presentOperands); + } else if (const auto *devicePtrClause = + std::get_if( + &clause.u)) { + genObjectList(devicePtrClause->v, converter, devicePtrOperands); + } else if (const auto *attachClause = + std::get_if(&clause.u)) { + genObjectList(attachClause->v, converter, attachOperands); + } else if (const auto *privateClause = + std::get_if( + &clause.u)) { + genObjectList(privateClause->v, converter, privateOperands); + } else if (const auto *firstprivateClause = + std::get_if( + &clause.u)) { + genObjectList(firstprivateClause->v, converter, firstprivateOperands); + } + } + + // Prepare the operand segment size attribute and the operands 
value range. + SmallVector operands; + SmallVector operandSegments; + addOperand(operands, operandSegments, async); + addOperands(operands, operandSegments, waitOperands); + addOperand(operands, operandSegments, numGangs); + addOperand(operands, operandSegments, numWorkers); + addOperand(operands, operandSegments, vectorLength); + addOperand(operands, operandSegments, ifCond); + addOperand(operands, operandSegments, selfCond); + addOperands(operands, operandSegments, reductionOperands); + addOperands(operands, operandSegments, copyOperands); + addOperands(operands, operandSegments, copyinOperands); + addOperands(operands, operandSegments, copyinReadonlyOperands); + addOperands(operands, operandSegments, copyoutOperands); + addOperands(operands, operandSegments, copyoutZeroOperands); + addOperands(operands, operandSegments, createOperands); + addOperands(operands, operandSegments, createZeroOperands); + addOperands(operands, operandSegments, noCreateOperands); + addOperands(operands, operandSegments, presentOperands); + addOperands(operands, operandSegments, devicePtrOperands); + addOperands(operands, operandSegments, attachOperands); + addOperands(operands, operandSegments, privateOperands); + addOperands(operands, operandSegments, firstprivateOperands); + + auto parallelOp = createRegionOp( + firOpBuilder, currentLocation, operands, operandSegments); + + if (addAsyncAttr) + parallelOp.setAttr(mlir::acc::ParallelOp::getAsyncAttrName(), + firOpBuilder.getUnitAttr()); + if (addWaitAttr) + parallelOp.setAttr(mlir::acc::ParallelOp::getWaitAttrName(), + firOpBuilder.getUnitAttr()); + if (addSelfAttr) + parallelOp.setAttr(mlir::acc::ParallelOp::getSelfAttrName(), + firOpBuilder.getUnitAttr()); +} + +static void +genACC(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { + const auto &beginBlockDirective = + std::get(blockConstruct.t); + const auto &blockDirective = + 
std::get(beginBlockDirective.t); + const auto &accClauseList = + std::get(beginBlockDirective.t); + + if (blockDirective.v == llvm::acc::ACCD_parallel) { + genACCParallelOp(converter, accClauseList); + } +} + void Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -235,7 +443,7 @@ std::visit( common::visitors{ [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { - TODO(); + genACC(converter, eval, blockConstruct); }, [&](const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) { TODO(); },