Patch notes (free-text preamble; the patch proper starts at the `diff --git` header below).

Summary: refactor how acc.loop operands and operand-segment sizes are built
in flang/lib/Lower/OpenACC.cpp.

* genObjectList drops its objectsCount out-parameter. It now only appends the
  symbol address of each listed object to `operands`, so callers get the count
  from the size of the vector they pass in.
* New helper addOperands appends a clause's operand list to the flat operand
  list and records the list's size as that clause's segment size.
* New helper addOperand does the same for one optional value. It records a
  segment of 1 if the value is set, and 0 otherwise.
* New template createRegionOp:
  - creates the op, passing an empty `argTy` list (presumably the result types);
  - creates the op's region block and puts a terminator in it;
  - sets the operand-segment-size attribute from the recorded segments;
  - leaves the builder's insertion point at the start of that block.
* genACC (loop construct) now collects tile, private and reduction operands
  in separate vectors. After the clause loop it assembles the operand list in
  a fixed order: gangNum, gangStatic, workerNum, vectorLength, tile, private,
  reduction.
  - Before this patch, operands were pushed in clause order, while the segment
    sizes assumed the fixed order above.
  - So a `private` clause written before `gang` produced an operand list that
    did not line up with its segments. This patch fixes that.
* reductionOperands is never filled in this patch (the reduction clause is
  still skipped), so its segment is always 0.

NOTE(review): this copy of the patch is damaged and will not apply as-is.
  - Line breaks and indentation have been collapsed.
  - Template argument lists appear stripped (e.g. `SmallVectorImpl &operands`,
    `template static Op`, `builder.create(loc)`), which is not valid C++.
  - Re-fetch the original patch before applying it.
NOTE(review): typo in an added comment: "operand segement size" should read
"operand segment size".
NOTE(review): in createRegionOp, the first setInsertionPointToStart(&block)
(before the terminator is created) may be redundant. It is redundant if
OpBuilder::createBlock already moves the insertion point into the new block;
confirm against the MLIR OpBuilder docs. The second call, after the
terminator, presumably exists because creating the terminator leaves the
insertion point after it.
NOTE(review): where the insertion point is reset has moved.
  - The old code reset it to the block start at the very end of genACC.
  - The new code resets it inside createRegionOp, before the remaining loop
    attributes are set.
  - The visible hunks only call setAttr in between. Confirm that the lines not
    shown in this diff do not create ops that relied on the old insertion point.

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -30,20 +30,17 @@ static void genObjectList(const Fortran::parser::AccObjectList &objectList, Fortran::lower::AbstractConverter &converter, - std::int32_t &objectsCount, - SmallVector &operands) { + SmallVectorImpl &operands) { for (const auto &accObject : objectList.v) { std::visit( Fortran::common::visitors{ [&](const Fortran::parser::Designator &designator) { if (const auto *name = getDesignatorNameIfDataRef(designator)) { - ++objectsCount; const auto variable = converter.getSymbolAddress(*name->symbol); operands.push_back(variable); } }, [&](const Fortran::parser::Name &name) { - ++objectsCount; const auto variable = converter.getSymbolAddress(*name.symbol); operands.push_back(variable); }}, @@ -51,6 +48,45 @@ } } +static void addOperands(SmallVectorImpl &operands, + SmallVectorImpl &operandSegments, + const SmallVectorImpl &clauseOperands) { + operands.append(clauseOperands.begin(), clauseOperands.end()); + operandSegments.push_back(clauseOperands.size()); +} + +static void addOperand(SmallVectorImpl &operands, + SmallVectorImpl &operandSegments, + const Value &clauseOperand) { + if (clauseOperand) { + operands.push_back(clauseOperand); + operandSegments.push_back(1); + } else { + operandSegments.push_back(0); + } +} + +template +static Op createRegionOp(Fortran::lower::FirOpBuilder &builder, + mlir::Location loc, + const SmallVectorImpl &operands, + const SmallVectorImpl &operandSegments) { + llvm::ArrayRef argTy; + Op op = builder.create(loc, argTy, operands); + builder.createBlock(&op.getRegion()); + auto &block = op.getRegion().back(); + builder.setInsertionPointToStart(&block); + builder.create(loc); + + op.setAttr(Op::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr(operandSegments)); + + // Place the insertion point to the start of the first block. 
+ builder.setInsertionPointToStart(&block); + + return op; +} + static void genACC(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { @@ -73,11 +109,8 @@ mlir::Value vectorLength; mlir::Value gangNum; mlir::Value gangStatic; - std::int32_t tileOperands = 0; - std::int32_t privateOperands = 0; - std::int32_t reductionOperands = 0; + SmallVector tileOperands, privateOperands, reductionOperands; std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE; - SmallVector operands; // Lower clauses values mapped to operands. for (const auto &clause : accClauseList.v) { @@ -90,7 +123,6 @@ x.t)) { gangNum = converter.genExprValue( *Fortran::semantics::GetExpr(gangNumValue.value())); - operands.push_back(gangNum); } if (const auto &gangStaticValue = std::get>(x.t)) { @@ -107,7 +139,6 @@ currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1); } - operands.push_back(gangStatic); } } executionMapping |= mlir::acc::OpenACCExecMapping::GANG; @@ -117,7 +148,6 @@ if (workerClause->v) { workerNum = converter.genExprValue( *Fortran::semantics::GetExpr(*workerClause->v)); - operands.push_back(workerNum); } executionMapping |= mlir::acc::OpenACCExecMapping::WORKER; } else if (const auto *vectorClause = @@ -126,7 +156,6 @@ if (vectorClause->v) { vectorLength = converter.genExprValue( *Fortran::semantics::GetExpr(*vectorClause->v)); - operands.push_back(vectorLength); } executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR; } else if (const auto *tileClause = @@ -136,9 +165,8 @@ const auto &expr = std::get>( accTileExpr.t); - ++tileOperands; if (expr) { - operands.push_back( + tileOperands.push_back( converter.genExprValue(*Fortran::semantics::GetExpr(*expr))); } else { // * was passed as value and will be represented as a -1 constant @@ -146,33 +174,31 @@ mlir::Value tileStar = firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getIntegerType(32), /* STAR 
*/ -1); - operands.push_back(tileStar); + tileOperands.push_back(tileStar); } } } else if (const auto *privateClause = std::get_if( &clause.u)) { - const Fortran::parser::AccObjectList &accObjectList = privateClause->v; - genObjectList(accObjectList, converter, privateOperands, operands); + genObjectList(privateClause->v, converter, privateOperands); } // Reduction clause is left out for the moment as the clause will probably // end up having its own operation. } - auto loopOp = firOpBuilder.create(currentLocation, argTy, - operands); - - firOpBuilder.createBlock(&loopOp.getRegion()); - auto &block = loopOp.getRegion().back(); - firOpBuilder.setInsertionPointToStart(&block); - // ensure the block is well-formed. - firOpBuilder.create(currentLocation); + // Prepare the operand segement size attribute and the operands value range. + SmallVector operands; + SmallVector operandSegments; + addOperand(operands, operandSegments, gangNum); + addOperand(operands, operandSegments, gangStatic); + addOperand(operands, operandSegments, workerNum); + addOperand(operands, operandSegments, vectorLength); + addOperands(operands, operandSegments, tileOperands); + addOperands(operands, operandSegments, privateOperands); + addOperands(operands, operandSegments, reductionOperands); - loopOp.setAttr(mlir::acc::LoopOp::getOperandSegmentSizeAttr(), - firOpBuilder.getI32VectorAttr( - {gangNum ? 1 : 0, gangStatic ? 1 : 0, workerNum ? 1 : 0, - vectorLength ? 1 : 0, tileOperands, privateOperands, - reductionOperands})); + auto loopOp = createRegionOp( + firOpBuilder, currentLocation, operands, operandSegments); loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), firOpBuilder.getI64IntegerAttr(executionMapping)); @@ -199,9 +225,6 @@ firOpBuilder.getUnitAttr()); } } - - // Place the insertion point to the start of the first block. - firOpBuilder.setInsertionPointToStart(&block); } }