diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -22,6 +22,9 @@ #include "mlir/Dialect/OpenACC/OpenACC.h" #include "llvm/Frontend/OpenACC/ACC.h.inc" +// Special value for * passed in device_type or gang clauses. +static constexpr std::int64_t starCst{-1}; + static const Fortran::parser::Name * getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { const auto *dataRef{std::get_if(&designator.u)}; @@ -30,7 +33,7 @@ static void genObjectList(const Fortran::parser::AccObjectList &objectList, Fortran::lower::AbstractConverter &converter, - SmallVectorImpl &operands) { + SmallVectorImpl &operands) { for (const auto &accObject : objectList.v) { std::visit( Fortran::common::visitors{ @@ -53,8 +56,8 @@ genObjectListWithModifier(const Clause *x, Fortran::lower::AbstractConverter &converter, Fortran::parser::AccDataModifier::Modifier mod, - SmallVectorImpl &operandsWithModifier, - SmallVectorImpl &operands) { + SmallVectorImpl &operandsWithModifier, + SmallVectorImpl &operands) { const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; const Fortran::parser::AccObjectList &accObjectList = std::get(listWithModifier.t); @@ -68,14 +71,14 @@ } } -static void addOperands(SmallVectorImpl &operands, +static void addOperands(SmallVectorImpl &operands, SmallVectorImpl &operandSegments, - const SmallVectorImpl &clauseOperands) { + const SmallVectorImpl &clauseOperands) { operands.append(clauseOperands.begin(), clauseOperands.end()); operandSegments.push_back(clauseOperands.size()); } -static void addOperand(SmallVectorImpl &operands, +static void addOperand(SmallVectorImpl &operands, SmallVectorImpl &operandSegments, const Value &clauseOperand) { if (clauseOperand) { @@ -89,7 +92,7 @@ template static Op createRegionOp(Fortran::lower::FirOpBuilder &builder, mlir::Location loc, - const SmallVectorImpl &operands, + const SmallVectorImpl &operands, const SmallVectorImpl &operandSegments) { llvm::ArrayRef argTy; Op op = builder.create(loc, argTy, operands); @@ -110,7 +113,7 @@ template static Op createSimpleOp(Fortran::lower::FirOpBuilder &builder, mlir::Location loc, - const SmallVectorImpl &operands, + const SmallVectorImpl &operands, const SmallVectorImpl &operandSegments) { llvm::ArrayRef argTy; Op op = builder.create(loc, argTy, operands); @@ -119,160 +122,230 @@ return op; } -static void genACC(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { - - const auto &beginLoopDirective = - std::get(loopConstruct.t); - const auto &loopDirective = - std::get(beginLoopDirective.t); +static void genAsyncClause(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClause::Async *asyncClause, + mlir::Value &async, bool &addAsyncAttr) { + const auto &asyncClauseValue = asyncClause->v; + if (asyncClauseValue) { // async has a value. + async = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*asyncClauseValue))); + } else { + addAsyncAttr = true; + } +} - if (loopDirective.v == llvm::acc::ACCD_loop) { +static void genDeviceTypeClause( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClause::DeviceType *deviceTypeClause, + SmallVectorImpl &operands) { + const auto &deviceTypeValue = deviceTypeClause->v; + if (deviceTypeValue) { + for (const auto &scalarIntExpr : *deviceTypeValue) { + mlir::Value expr = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(scalarIntExpr))); + operands.push_back(expr); + } + } else { auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + // * was passed as value and will be represented as a special constant. + mlir::Value star = firOpBuilder.createIntegerConstant( + converter.getCurrentLocation(), firOpBuilder.getIndexType(), starCst); + operands.push_back(star); + } +} - // Add attribute extracted from clauses. - const auto &accClauseList = - std::get(beginLoopDirective.t); +static void genIfClause(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClause::If *ifClause, + mlir::Value &ifCond) { + auto &firOpBuilder = converter.getFirOpBuilder(); + Value cond = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); + ifCond = firOpBuilder.createConvert(converter.getCurrentLocation(), + firOpBuilder.getI1Type(), cond); +} - mlir::Value workerNum; - mlir::Value vectorLength; - mlir::Value gangNum; - mlir::Value gangStatic; - SmallVector tileOperands, privateOperands, reductionOperands; - std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE; - - // Lower clauses values mapped to operands. - for (const auto &clause : accClauseList.v) { - if (const auto *gangClause = - std::get_if(&clause.u)) { - if (gangClause->v) { - const Fortran::parser::AccGangArgument &x = *gangClause->v; - if (const auto &gangNumValue = - std::get>( - x.t)) { - gangNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(gangNumValue.value()))); - } - if (const auto &gangStaticValue = - std::get>(x.t)) { - const auto &expr = - std::get>( - gangStaticValue.value().t); - if (expr) { - gangStatic = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(*expr))); - } else { - // * was passed as value and will be represented as a -1 constant - // integer. - gangStatic = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), - /* STAR */ -1); - } - } - } - executionMapping |= mlir::acc::OpenACCExecMapping::GANG; - } else if (const auto *workerClause = - std::get_if( - &clause.u)) { - if (workerClause->v) { - workerNum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*workerClause->v))); - } - executionMapping |= mlir::acc::OpenACCExecMapping::WORKER; - } else if (const auto *vectorClause = - std::get_if( - &clause.u)) { - if (vectorClause->v) { - vectorLength = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*vectorClause->v))); +static void genWaitClause(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClause::Wait *waitClause, + SmallVectorImpl &operands, + mlir::Value &waitDevnum, bool &addWaitAttr) { + const auto &waitClauseValue = waitClause->v; + if (waitClauseValue) { // wait has a value. + const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; + const std::list &waitList = + std::get>(waitArg.t); + for (const Fortran::parser::ScalarIntExpr &value : waitList) { + mlir::Value v = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(value))); + operands.push_back(v); + } + + const std::optional &waitDevnumValue = + std::get>(waitArg.t); + if (waitDevnumValue) + waitDevnum = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*waitDevnumValue))); + } else { + addWaitAttr = true; + } +} + +static mlir::acc::LoopOp +createLoopOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClauseList &accClauseList) { + auto &firOpBuilder = converter.getFirOpBuilder(); + auto currentLocation = converter.getCurrentLocation(); + + mlir::Value workerNum; + mlir::Value vectorNum; + mlir::Value gangNum; + mlir::Value gangStatic; + SmallVector tileOperands, privateOperands, reductionOperands; + std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE; + + for (const auto &clause : accClauseList.v) { + if (const auto *gangClause = + std::get_if(&clause.u)) { + if (gangClause->v) { + const Fortran::parser::AccGangArgument &x = *gangClause->v; + if (const auto &gangNumValue = + std::get>(x.t)) { + gangNum = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(gangNumValue.value()))); } - executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR; - } else if (const auto *tileClause = - std::get_if(&clause.u)) { - const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; - for (const auto &accTileExpr : accTileExprList.v) { + if (const auto &gangStaticValue = + std::get>(x.t)) { const auto &expr = - std::get>( - accTileExpr.t); + std::get>( + gangStaticValue.value().t); if (expr) { - tileOperands.push_back(fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(*expr)))); + gangStatic = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(*expr))); } else { - // * was passed as value and will be represented as a -1 constant - // integer. - mlir::Value tileStar = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), - /* STAR */ -1); - tileOperands.push_back(tileStar); + // * was passed as value and will be represented as a special + // constant. + gangStatic = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIndexType(), starCst); } } - } else if (const auto *privateClause = - std::get_if( - &clause.u)) { - genObjectList(privateClause->v, converter, privateOperands); } - // Reduction clause is left out for the moment as the clause will probably - // end up having its own operation. + executionMapping |= mlir::acc::OpenACCExecMapping::GANG; + } else if (const auto *workerClause = + std::get_if(&clause.u)) { + if (workerClause->v) { + workerNum = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*workerClause->v))); + } + executionMapping |= mlir::acc::OpenACCExecMapping::WORKER; + } else if (const auto *vectorClause = + std::get_if(&clause.u)) { + if (vectorClause->v) { + vectorNum = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*vectorClause->v))); + } + executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR; + } else if (const auto *tileClause = + std::get_if(&clause.u)) { + const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; + for (const auto &accTileExpr : accTileExprList.v) { + const auto &expr = + std::get>( + accTileExpr.t); + if (expr) { + tileOperands.push_back(fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(*expr)))); + } else { + // * was passed as value and will be represented as a -1 constant + // integer. + mlir::Value tileStar = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIntegerType(32), + /* STAR */ -1); + tileOperands.push_back(tileStar); + } + } + } else if (const auto *privateClause = + std::get_if( + &clause.u)) { + genObjectList(privateClause->v, converter, privateOperands); } + // Reduction clause is left out for the moment as the clause will probably + // end up having its own operation. + } - // Prepare the operand segement size attribute and the operands value range. - SmallVector operands; - SmallVector operandSegments; - addOperand(operands, operandSegments, gangNum); - addOperand(operands, operandSegments, gangStatic); - addOperand(operands, operandSegments, workerNum); - addOperand(operands, operandSegments, vectorLength); - addOperands(operands, operandSegments, tileOperands); - addOperands(operands, operandSegments, privateOperands); - addOperands(operands, operandSegments, reductionOperands); - - auto loopOp = createRegionOp( - firOpBuilder, currentLocation, operands, operandSegments); - - loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), - firOpBuilder.getI64IntegerAttr(executionMapping)); - - // Lower clauses mapped to attributes - for (const auto &clause : accClauseList.v) { - if (const auto *collapseClause = - std::get_if(&clause.u)) { - const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); - const auto collapseValue = Fortran::evaluate::ToInt64(*expr); - if (collapseValue) { - loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(), - firOpBuilder.getI64IntegerAttr(*collapseValue)); - } - } else if (std::get_if(&clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(), - firOpBuilder.getUnitAttr()); - } else if (std::get_if( - &clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(), - firOpBuilder.getUnitAttr()); - } else if (std::get_if(&clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(), - firOpBuilder.getUnitAttr()); + // Prepare the operand segement size attribute and the operands value range. + SmallVector operands; + SmallVector operandSegments; + addOperand(operands, operandSegments, gangNum); + addOperand(operands, operandSegments, gangStatic); + addOperand(operands, operandSegments, workerNum); + addOperand(operands, operandSegments, vectorNum); + addOperands(operands, operandSegments, tileOperands); + addOperands(operands, operandSegments, privateOperands); + addOperands(operands, operandSegments, reductionOperands); + + auto loopOp = createRegionOp( + firOpBuilder, currentLocation, operands, operandSegments); + + loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), + firOpBuilder.getI64IntegerAttr(executionMapping)); + + // Lower clauses mapped to attributes + for (const auto &clause : accClauseList.v) { + if (const auto *collapseClause = + std::get_if(&clause.u)) { + const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); + const auto collapseValue = Fortran::evaluate::ToInt64(*expr); + if (collapseValue) { + loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(), + firOpBuilder.getI64IntegerAttr(*collapseValue)); } + } else if (std::get_if(&clause.u)) { + loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(), + firOpBuilder.getUnitAttr()); + } else if (std::get_if( + &clause.u)) { + loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(), + firOpBuilder.getUnitAttr()); + } else if (std::get_if(&clause.u)) { + loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(), + firOpBuilder.getUnitAttr()); } } + return loopOp; } -static void -genACCParallelOp(Fortran::lower::AbstractConverter &converter, +static void genACC(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { + + const auto &beginLoopDirective = + std::get(loopConstruct.t); + const auto &loopDirective = + std::get(beginLoopDirective.t); + + if (loopDirective.v == llvm::acc::ACCD_loop) { + const auto &accClauseList = + std::get(beginLoopDirective.t); + createLoopOp(converter, accClauseList); + } +} + +static mlir::acc::ParallelOp +createParallelOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { + + // Parallel operation operands mlir::Value async; mlir::Value numGangs; mlir::Value numWorkers; mlir::Value vectorLength; mlir::Value ifCond; mlir::Value selfCond; - SmallVector waitOperands, reductionOperands, copyOperands, + mlir::Value waitDevnum; + SmallVector waitOperands, reductionOperands, copyOperands, copyinOperands, copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, createOperands, createZeroOperands, noCreateOperands, - presentOperands, devicePtrOperands, attachOperands, privateOperands, - firstprivateOperands; + presentOperands, devicePtrOperands, attachOperands, firstprivateOperands, + privateOperands; // Async, wait and self clause have optional values but can be present with // no value as well. When there is no value, the op has an attribute to @@ -290,28 +363,11 @@ for (const auto &clause : accClauseList.v) { if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); - } else { - addAsyncAttr = true; - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr); } else if (const auto *waitClause = std::get_if(&clause.u)) { - const auto &waitClauseValue = waitClause->v; - if (waitClauseValue) { // wait has a value. - const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; - const std::list &waitList = - std::get>(waitArg.t); - for (const Fortran::parser::ScalarIntExpr &value : waitList) { - Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); - waitOperands.push_back(v); - } - } else { - addWaitAttr = true; - } + genWaitClause(converter, waitClause, waitOperands, waitDevnum, + addWaitAttr); } else if (const auto *numGangsClause = std::get_if( &clause.u)) { @@ -329,10 +385,7 @@ *Fortran::semantics::GetExpr(vectorLengthClause->v))); } else if (const auto *ifClause = std::get_if(&clause.u)) { - Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *selfClause = std::get_if(&clause.u)) { if (selfClause->v) { @@ -392,7 +445,7 @@ } // Prepare the operand segement size attribute and the operands value range. - SmallVector operands; + SmallVector operands; SmallVector operandSegments; addOperand(operands, operandSegments, async); addOperands(operands, operandSegments, waitOperands); @@ -428,14 +481,23 @@ if (addSelfAttr) parallelOp.setAttr(mlir::acc::ParallelOp::getSelfAttrName(), firOpBuilder.getUnitAttr()); + + return parallelOp; +} + +static void +genACCParallelOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClauseList &accClauseList) { + createParallelOp(converter, accClauseList); } static void genACCDataOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { mlir::Value ifCond; - SmallVector copyOperands, copyinOperands, copyinReadonlyOperands, - copyoutOperands, copyoutZeroOperands, createOperands, createZeroOperands, - noCreateOperands, presentOperands, deviceptrOperands, attachOperands; + SmallVector copyOperands, copyinOperands, + copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, + createOperands, createZeroOperands, noCreateOperands, presentOperands, + deviceptrOperands, attachOperands; auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); @@ -446,10 +508,7 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *copyClause = std::get_if(&clause.u)) { genObjectList(copyClause->v, converter, copyOperands); @@ -491,7 +550,7 @@ } // Prepare the operand segement size attribute and the operands value range. - SmallVector operands; + SmallVector operands; SmallVector operandSegments; addOperand(operands, operandSegments, ifCond); addOperands(operands, operandSegments, copyOperands); @@ -532,8 +591,8 @@ genACCEnterDataOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { mlir::Value ifCond, async, waitDevnum; - SmallVector copyinOperands, createOperands, createZeroOperands, - attachOperands, waitOperands; + SmallVector copyinOperands, createOperands, + createZeroOperands, attachOperands, waitOperands; // Async, wait and self clause have optional values but can be present with // no value as well. When there is no value, the op has an attribute to @@ -550,40 +609,14 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); - } else { - addAsyncAttr = true; - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr); } else if (const auto *waitClause = std::get_if(&clause.u)) { - const auto &waitClauseValue = waitClause->v; - if (waitClauseValue) { // wait has a value. - const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; - const std::list &waitList = - std::get>(waitArg.t); - for (const Fortran::parser::ScalarIntExpr &value : waitList) { - mlir::Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); - waitOperands.push_back(v); - } - - const std::optional &waitDevnumValue = - std::get>(waitArg.t); - if (waitDevnumValue) - waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); - } else { - addWaitAttr = true; - } + genWaitClause(converter, waitClause, waitOperands, waitDevnum, + addWaitAttr); } else if (const auto *copyinClause = std::get_if(&clause.u)) { const Fortran::parser::AccObjectListWithModifier &listWithModifier = @@ -631,7 +664,7 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { mlir::Value ifCond, async, waitDevnum; - SmallVector copyoutOperands, deleteOperands, detachOperands, + SmallVector copyoutOperands, deleteOperands, detachOperands, waitOperands; // Async and wait clause have optional values but can be present with @@ -650,40 +683,14 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); - } else { - addAsyncAttr = true; - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr); } else if (const auto *waitClause = std::get_if(&clause.u)) { - const auto &waitClauseValue = waitClause->v; - if (waitClauseValue) { // wait has a value. - const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; - const std::list &waitList = - std::get>(waitArg.t); - for (const Fortran::parser::ScalarIntExpr &value : waitList) { - Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); - waitOperands.push_back(v); - } - - const std::optional &waitDevnumValue = - std::get>(waitArg.t); - if (waitDevnumValue) - waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); - } else { - addWaitAttr = true; - } + genWaitClause(converter, waitClause, waitOperands, waitDevnum, + addWaitAttr); } else if (const auto *copyoutClause = std::get_if( &clause.u)) { @@ -730,7 +737,7 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { mlir::Value ifCond, deviceNum; - SmallVector deviceTypeOperands; + SmallVector deviceTypeOperands; auto &firOpBuilder = converter.getFirOpBuilder(); auto currentLocation = converter.getCurrentLocation(); @@ -741,10 +748,7 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *deviceNumClause = std::get_if( &clause.u)) { @@ -753,21 +757,7 @@ } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - - const auto &deviceTypeValue = deviceTypeClause->v; - if (deviceTypeValue) { - for (const auto &scalarIntExpr : *deviceTypeValue) { - mlir::Value expr = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(scalarIntExpr))); - deviceTypeOperands.push_back(expr); - } - } else { - // * was passed as value and will be represented as a -1 constant - // integer. - mlir::Value star = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1); - deviceTypeOperands.push_back(star); - } + genDeviceTypeClause(converter, deviceTypeClause, deviceTypeOperands); } } @@ -785,7 +775,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { mlir::Value ifCond, async, waitDevnum; - SmallVector hostOperands, deviceOperands, waitOperands, + SmallVector hostOperands, deviceOperands, waitOperands, deviceTypeOperands; // Async and wait clause have optional values but can be present with @@ -804,58 +794,18 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); - } else { - addAsyncAttr = true; - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr); } else if (const auto *waitClause = std::get_if(&clause.u)) { - const auto &waitClauseValue = waitClause->v; - if (waitClauseValue) { // wait has a value. - const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; - const std::list &waitList = - std::get>(waitArg.t); - for (const Fortran::parser::ScalarIntExpr &value : waitList) { - mlir::Value v = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(value))); - waitOperands.push_back(v); - } - - const std::optional &waitDevnumValue = - std::get>(waitArg.t); - if (waitDevnumValue) - waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue))); - } else { - addWaitAttr = true; - } + genWaitClause(converter, waitClause, waitOperands, waitDevnum, + addWaitAttr); } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - - const auto &deviceTypeValue = deviceTypeClause->v; - if (deviceTypeValue) { - for (const auto &scalarIntExpr : *deviceTypeValue) { - mlir::Value expr = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(scalarIntExpr))); - deviceTypeOperands.push_back(expr); - } - } else { - // * was passed as value and will be represented as a -1 constant - // integer. - mlir::Value star = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1); - deviceTypeOperands.push_back(star); - } + genDeviceTypeClause(converter, deviceTypeClause, deviceTypeOperands); } else if (const auto *hostClause = std::get_if(&clause.u)) { genObjectList(hostClause->v, converter, hostOperands); @@ -955,19 +905,10 @@ for (const auto &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v))); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond); } else if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue))); - } else { - addAsyncAttr = true; - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr); } }