diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -387,6 +387,167 @@
                                                  builder.getContext(), clause)));
 }
 
+/// Create a private func.func named \p funcName with argument types \p argsTy
+/// (at locations \p locs) and no results, whose body holds only a return.
+/// On return, \p builder is positioned at the start of the new entry block.
+static mlir::func::FuncOp
+createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
+                  mlir::Location loc, llvm::StringRef funcName,
+                  llvm::SmallVector argsTy = {},
+                  llvm::SmallVector locs = {}) {
+  auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {});
+  auto funcOp = modBuilder.create(loc, funcName, funcTy);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      locs);
+  builder.setInsertionPointToEnd(&funcOp.getRegion().back());
+  builder.create(loc);
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  return funcOp;
+}
+
+/// Create an \p Op with no results from \p operands and set its operand
+/// segment sizes attribute to \p operandSegments.
+template
+static Op
+createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
+               const llvm::SmallVectorImpl &operands,
+               const llvm::SmallVectorImpl &operandSegments) {
+  llvm::ArrayRef argTy;
+  Op op = builder.create(loc, argTy, operands);
+  op->setAttr(Op::getOperandSegmentSizeAttr(),
+              builder.getDenseI32ArrayAttr(operandSegments));
+  return op;
+}
+
+/// Generate the post-alloc hook (\p funcNamePrefix + declarePostAllocSuffix)
+/// for a descriptor variable. The hook takes a reference to the descriptor,
+/// applies \p EntryOp to the data it points to and then updates the
+/// descriptor itself on the device.
+template
+static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
+                                          fir::FirOpBuilder &builder,
+                                          mlir::Location loc, mlir::Type descTy,
+                                          llvm::StringRef funcNamePrefix,
+                                          std::stringstream &asFortran,
+                                          mlir::acc::DataClause clause) {
+  auto crtInsPt = builder.saveInsertionPoint();
+  std::stringstream registerFuncName;
+  registerFuncName << funcNamePrefix.str()
+                   << Fortran::lower::declarePostAllocSuffix.str();
+
+  if (!mlir::isa(descTy))
+    descTy = fir::ReferenceType::get(descTy);
+  auto registerFuncOp = createDeclareFunc(
+      modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});
+
+  mlir::Value desc =
+      builder.create(loc, registerFuncOp.getArgument(0));
+  fir::BoxAddrOp boxAddrOp = builder.create(loc, desc);
+  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
+
+  llvm::SmallVector bounds;
+  EntryOp entryOp = createDataEntryOp(
+      builder, loc, boxAddrOp.getResult(), asFortran, bounds,
+      /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
+  builder.create(
+      loc, mlir::ValueRange(entryOp.getAccPtr()));
+
+  // Use a separate stream for the "_desc" name: the caller reuses asFortran
+  // to name the data operations of the dealloc hooks.
+  std::stringstream asFortranDesc(asFortran.str() + "_desc");
+  mlir::acc::UpdateDeviceOp updateDeviceOp =
+      createDataEntryOp(
+          builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
+          /*structured=*/false, /*implicit=*/true,
+          mlir::acc::DataClause::acc_update_device, descTy);
+  llvm::SmallVector operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector operands{updateDeviceOp.getResult()};
+  createSimpleOp(builder, loc, operands, operandSegments);
+  modBuilder.setInsertionPointAfter(registerFuncOp);
+  builder.restoreInsertionPoint(crtInsPt);
+}
+
+/// Generate the pre-dealloc and post-dealloc hooks for a descriptor variable.
+/// The pre-dealloc hook exits the data the descriptor points to (with the
+/// clause's exit operation); the post-dealloc hook updates the descriptor
+/// itself on the device, like the post-alloc hook does.
+template
+static void createDeclareDeallocFuncWithArg(
+    mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
+    mlir::Type descTy, llvm::StringRef funcNamePrefix,
+    std::stringstream &asFortran, mlir::acc::DataClause clause) {
+  auto crtInsPt = builder.saveInsertionPoint();
+  // Generate the pre dealloc function.
+  std::stringstream preDeallocFuncName;
+  preDeallocFuncName << funcNamePrefix.str()
+                     << Fortran::lower::declarePreDeallocSuffix.str();
+  if (!mlir::isa(descTy))
+    descTy = fir::ReferenceType::get(descTy);
+  auto preDeallocOp = createDeclareFunc(
+      modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
+  mlir::Value loadOp =
+      builder.create(loc, preDeallocOp.getArgument(0));
+  fir::BoxAddrOp boxAddrOp = builder.create(loc, loadOp);
+  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
+
+  llvm::SmallVector bounds;
+  mlir::acc::GetDevicePtrOp entryOp =
+      createDataEntryOp(
+          builder, loc, boxAddrOp.getResult(), asFortran, bounds,
+          /*structured=*/false, /*implicit=*/false, clause,
+          boxAddrOp.getType());
+  builder.create(
+      loc, mlir::ValueRange(entryOp.getAccPtr()));
+
+  mlir::Value varPtr;
+  if constexpr (std::is_same_v ||
+                std::is_same_v)
+    varPtr = entryOp.getVarPtr();
+  builder.create(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
+                         entryOp.getBounds(), entryOp.getDataClause(),
+                         /*structured=*/false, /*implicit=*/false,
+                         builder.getStringAttr(*entryOp.getName()));
+
+  // Generate the post dealloc function.
+  modBuilder.setInsertionPointAfter(preDeallocOp);
+  std::stringstream postDeallocFuncName;
+  postDeallocFuncName << funcNamePrefix.str()
+                      << Fortran::lower::declarePostDeallocSuffix.str();
+  auto postDeallocOp = createDeclareFunc(
+      modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
+  // Update the descriptor argument itself, not the data it pointed to: this
+  // runs after deallocation, so that data is no longer valid.
+  mlir::Value descArg = postDeallocOp.getArgument(0);
+  std::stringstream asFortranDesc(asFortran.str() + "_desc");
+  mlir::acc::UpdateDeviceOp updateDeviceOp =
+      createDataEntryOp(
+          builder, loc, descArg, asFortranDesc, bounds,
+          /*structured=*/false, /*implicit=*/true,
+          mlir::acc::DataClause::acc_update_device, descTy);
+  llvm::SmallVector operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector operands{updateDeviceOp.getResult()};
+  createSimpleOp(builder, loc, operands, operandSegments);
+  modBuilder.setInsertionPointAfter(postDeallocOp);
+  builder.restoreInsertionPoint(crtInsPt);
+}
+
+/// Return the symbol of \p accObject, which must be a name or a designator
+/// that is a plain data reference; report a fatal error otherwise.
+static Fortran::semantics::Symbol &
+getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
+  if (const auto *designator =
+          std::get_if(&accObject.u)) {
+    if (const auto *name =
+            Fortran::semantics::getDesignatorNameIfDataRef(*designator))
+      return *name->symbol;
+  } else if (const auto *name =
+                 std::get_if(&accObject.u)) {
+    return *name->symbol;
+  }
+  llvm::report_fatal_error("Could not find symbol");
+}
+
 template
 static void
 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
@@ -408,11 +569,75 @@
                                   bounds, structured, implicit, dataClause,
                                   baseAddr.getType());
     dataOperands.push_back(op.getAccPtr());
-    if (setDeclareAttr)
-      addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
   }
 }
 
+/// Declare-directive variant of genDataOperandOperations: every operand also
+/// gets the acc.declare attribute, and descriptor (box) operands additionally
+/// get module-level post-alloc and, when the clause has a distinct exit
+/// operation, pre/post-dealloc hooks generated after the current function.
+template
+static void genDeclareDataOperandOperations(
+    const Fortran::parser::AccObjectList &objectList,
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl &dataOperands,
+    mlir::acc::DataClause dataClause, bool structured, bool implicit) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  for (const auto &accObject : objectList.v) {
+    llvm::SmallVector bounds;
+    std::stringstream asFortran;
+    mlir::Location operandLocation = genOperandLocation(converter, accObject);
+    mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
+        converter, builder, semanticsContext, stmtCtx, accObject,
+        operandLocation, asFortran, bounds);
+    EntryOp op = createDataEntryOp(
+        builder, operandLocation, baseAddr, asFortran, bounds, structured,
+        implicit, dataClause, baseAddr.getType());
+    dataOperands.push_back(op.getAccPtr());
+    addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
+    if (mlir::isa(fir::unwrapRefType(baseAddr.getType()))) {
+      mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
+      modBuilder.setInsertionPointAfter(builder.getFunction());
+      std::string prefix =
+          converter.mangleName(getSymbolFromAccObject(accObject));
+      createDeclareAllocFuncWithArg(
+          modBuilder, builder, operandLocation, baseAddr.getType(), prefix,
+          asFortran, dataClause);
+      if constexpr (!std::is_same_v)
+        createDeclareDeallocFuncWithArg(
+            modBuilder, builder, operandLocation, baseAddr.getType(), prefix,
+            asFortran, dataClause);
+    }
+  }
+}
+
+/// Like genDeclareDataOperandOperations for clauses with an optional modifier:
+/// \p clauseWithModifier is used when the modifier is \p mod, else \p clause.
+template
+static void genDeclareDataOperandOperationsWithModifier(
+    const Clause *x, Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    Fortran::parser::AccDataModifier::Modifier mod,
+    llvm::SmallVectorImpl &dataClauseOperands,
+    const mlir::acc::DataClause clause,
+    const mlir::acc::DataClause clauseWithModifier) {
+  const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
+  const auto &accObjectList =
+      std::get(listWithModifier.t);
+  const auto &modifier =
+      std::get>(
+          listWithModifier.t);
+  mlir::acc::DataClause dataClause =
+      (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
+  genDeclareDataOperandOperations(
+      accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
+      dataClause,
+      /*structured=*/true, /*implicit=*/false);
+}
+
 template
 static void genDataExitOperations(fir::FirOpBuilder &builder,
                                   llvm::SmallVector operands,
@@ -1058,18 +1283,6 @@
   return op;
 }
 
-template
-static Op
-createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
-               const llvm::SmallVectorImpl &operands,
-               const llvm::SmallVectorImpl &operandSegments) {
-  llvm::ArrayRef argTy;
-  Op op = builder.create(loc, argTy, operands);
-  op->setAttr(Op::getOperandSegmentSizeAttr(),
-              builder.getDenseI32ArrayAttr(operandSegments));
-  return op;
-}
-
 static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
                            const Fortran::parser::AccClause::Async *asyncClause,
                            mlir::Value &async, bool &addAsyncAttr,
@@ -2349,20 +2562,6 @@
   modBuilder.setInsertionPointAfter(declareGlobalOp);
 }
 
-static mlir::func::FuncOp createDeclareFunc(mlir::OpBuilder &modBuilder,
-                                            fir::FirOpBuilder &builder,
-                                            mlir::Location loc,
-                                            llvm::StringRef funcName) {
-  auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), {}, {});
-  auto funcOp = modBuilder.create(loc, funcName, funcTy);
-  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
-  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), {}, {});
-  builder.setInsertionPointToEnd(&funcOp.getRegion().back());
-  builder.create(loc);
-  builder.setInsertionPointToStart(&funcOp.getRegion().back());
-  return funcOp;
-}
-
 template
 static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
                                    fir::FirOpBuilder &builder,
@@ -2556,10 +2755,11 @@
     if (const auto *copyClause =
             std::get_if(&clause.u)) {
       auto crtDataStart = dataClauseOperands.size();
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           copyClause->v, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_copy,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
       copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
                                dataClauseOperands.end());
     } else if (const auto *createClause =
@@ -2569,26 +2769,28 @@
       const auto &accObjectList =
           std::get(listWithModifier.t);
       auto crtDataStart = dataClauseOperands.size();
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           accObjectList, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_create,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
       createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
                                  dataClauseOperands.end());
     } else if (const auto *presentClause =
                    std::get_if(
                        &clause.u)) {
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           presentClause->v, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_present,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
     } else if (const auto *copyinClause =
                    std::get_if(&clause.u)) {
-      genDataOperandOperationsWithModifier(
+      genDeclareDataOperandOperationsWithModifier(
           copyinClause, converter, semanticsContext, stmtCtx,
           Fortran::parser::AccDataModifier::Modifier::ReadOnly,
           dataClauseOperands, mlir::acc::DataClause::acc_copyin,
-          mlir::acc::DataClause::acc_copyin_readonly, /*setDeclareAttr=*/true);
+          mlir::acc::DataClause::acc_copyin_readonly);
     } else if (const auto *copyoutClause =
                    std::get_if(
                        &clause.u)) {
@@ -2597,34 +2799,38 @@
       const auto &accObjectList =
           std::get(listWithModifier.t);
       auto crtDataStart = dataClauseOperands.size();
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           accObjectList, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_copyout,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
       copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
                                   dataClauseOperands.end());
     } else if (const auto *devicePtrClause =
                    std::get_if(
                        &clause.u)) {
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           devicePtrClause->v, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
     } else if (const auto *linkClause =
                    std::get_if(&clause.u)) {
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           linkClause->v, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
     } else if (const auto *deviceResidentClause =
                    std::get_if(
                        &clause.u)) {
       auto crtDataStart = dataClauseOperands.size();
-      genDataOperandOperations(
+      genDeclareDataOperandOperations(
           deviceResidentClause->v, converter, semanticsContext, stmtCtx,
           dataClauseOperands, mlir::acc::DataClause::acc_declare_device_resident,
-          /*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
+          /*structured=*/true, /*implicit=*/false);
       deviceResidentEntryOperands.append(
           dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
     } else {
diff --git a/flang/test/Lower/OpenACC/acc-declare.f90 b/flang/test/Lower/OpenACC/acc-declare.f90
--- a/flang/test/Lower/OpenACC/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/acc-declare.f90
@@ -283,6 +283,34 @@
 
 end subroutine
 
+! CHECK-LABEL: func.func private @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_post_alloc(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref>>>) {
+! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0]] : !fir.ref>>>
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] {acc.declare = #acc.declare} : (!fir.box>>) -> !fir.heap>
+! CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.heap>) -> !fir.heap> {name = "a", structured = false}
+! CHECK: acc.declare_enter dataOperands(%[[CREATE]] : !fir.heap>)
+! CHECK: %[[UPDATE:.*]] = acc.update_device varPtr(%[[ARG0]] : !fir.ref>>>) -> !fir.ref>>> {implicit = true, name = "a_desc", structured = false}
+! CHECK: acc.update dataOperands(%[[UPDATE]] : !fir.ref>>>)
+! CHECK: return
+! CHECK: }
+
+! CHECK-LABEL: func.func private @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_pre_dealloc(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref>>>) {
+! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0]] : !fir.ref>>>
+! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] {acc.declare = #acc.declare} : (!fir.box>>) -> !fir.heap>
+! CHECK: %[[GETDEVICEPTR:.*]] = acc.getdeviceptr varPtr(%[[BOX_ADDR]] : !fir.heap>) -> !fir.heap> {dataClause = #acc, name = "a", structured = false}
+! CHECK: acc.declare_exit dataOperands(%[[GETDEVICEPTR]] : !fir.heap>)
+! CHECK: acc.delete accPtr(%[[GETDEVICEPTR]] : !fir.heap>) {dataClause = #acc, name = "a", structured = false}
+! CHECK: return
+! CHECK: }
+
+! CHECK-LABEL: func.func private @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_post_dealloc(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref>>>) {
+! CHECK: %[[UPDATE:.*]] = acc.update_device varPtr(%[[ARG0]] : !fir.ref>>>) -> !fir.ref>>> {implicit = true, name = "a_desc", structured = false}
+! CHECK: acc.update dataOperands(%[[UPDATE]] : !fir.ref>>>)
+! CHECK: return
+! CHECK: }
+
 end module
 
 module acc_declare_allocatable_test