diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -51,6 +51,13 @@
 struct Evaluation;
 } // namespace pft
 
+/// Suffixes appended to the symbol name of a global to build the names of the
+/// pre/post deallocation functions generated for the acc declare directive.
+inline constexpr llvm::StringRef declarePreDeallocSuffix =
+    "_acc_declare_update_desc_pre_dealloc";
+inline constexpr llvm::StringRef declarePostDeallocSuffix =
+    "_acc_declare_update_desc_post_dealloc";
+
 void genOpenACCConstruct(AbstractConverter &,
                          Fortran::semantics::SemanticsContext &,
                          pft::Evaluation &, const parser::OpenACCConstruct &);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2392,6 +2392,96 @@
   modBuilder.setInsertionPointAfter(registerFuncOp);
 }
 
+/// Create an empty private function named \p funcName that takes no arguments
+/// and returns no results. The function is created at the insertion point of
+/// \p modBuilder. On return, \p builder is positioned at the start of the
+/// function body, ahead of the func.return terminator.
+static mlir::func::FuncOp createDeclareFunc(mlir::OpBuilder &modBuilder,
+                                            fir::FirOpBuilder &builder,
+                                            mlir::Location loc,
+                                            llvm::StringRef funcName) {
+  auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), {}, {});
+  auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), {}, {});
+  // Emit the terminator first, then move back to the start of the block so
+  // that callers insert their operations ahead of it.
+  builder.setInsertionPointToEnd(&funcOp.getRegion().back());
+  builder.create<mlir::func::ReturnOp>(loc);
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  return funcOp;
+}
+
+/// Actions to be performed on deallocation are split into two distinct
+/// functions:
+/// - The pre deallocation function includes all the actions to be performed
+///   before the actual deallocation is done on the host side.
+/// - The post deallocation function includes the update to the descriptor.
+template <typename ExitOp>
+static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
+                                     fir::FirOpBuilder &builder,
+                                     mlir::Location loc,
+                                     fir::GlobalOp &globalOp,
+                                     mlir::acc::DataClause clause) {
+
+  // Generate the pre dealloc function.
+  std::stringstream preDeallocFuncName;
+  preDeallocFuncName << globalOp.getSymName().str()
+                     << Fortran::lower::declarePreDeallocSuffix.str();
+  auto preDeallocOp =
+      createDeclareFunc(modBuilder, builder, loc, preDeallocFuncName.str());
+  fir::AddrOfOp addrOp = builder.create<fir::AddrOfOp>(
+      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
+  auto loadOp = builder.create<fir::LoadOp>(loc, addrOp.getResult());
+  fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
+  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
+
+  std::stringstream asFortran;
+  asFortran << Fortran::lower::mangle::demangleName(globalOp.getSymName());
+  llvm::SmallVector<mlir::Value> bounds;
+  mlir::acc::GetDevicePtrOp entryOp =
+      createDataEntryOp<mlir::acc::GetDevicePtrOp>(
+          builder, loc, boxAddrOp.getResult(), asFortran, bounds,
+          /*structured=*/false, /*implicit=*/false, clause,
+          boxAddrOp.getType());
+
+  builder.create<mlir::acc::DeclareExitOp>(
+      loc, mlir::ValueRange(entryOp.getAccPtr()));
+
+  // Only acc.copyout and acc.update_host exit operations are given the host
+  // address; any other exit operation is created with a null varPtr.
+  mlir::Value varPtr;
+  if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
+                std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
+    varPtr = entryOp.getVarPtr();
+  builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
+                         entryOp.getBounds(), entryOp.getDataClause(),
+                         /*structured=*/false, /*implicit=*/false,
+                         builder.getStringAttr(*entryOp.getName()));
+
+  // Generate the post dealloc function.
+  modBuilder.setInsertionPointAfter(preDeallocOp);
+  std::stringstream postDeallocFuncName;
+  postDeallocFuncName << globalOp.getSymName().str()
+                      << Fortran::lower::declarePostDeallocSuffix.str();
+  auto postDeallocOp =
+      createDeclareFunc(modBuilder, builder, loc, postDeallocFuncName.str());
+
+  addrOp = builder.create<fir::AddrOfOp>(
+      loc, fir::ReferenceType::get(globalOp.getType()), globalOp.getSymbol());
+  asFortran << "_desc";
+  mlir::acc::UpdateDeviceOp updateDeviceOp =
+      createDataEntryOp<mlir::acc::UpdateDeviceOp>(
+          builder, loc, addrOp, asFortran, bounds,
+          /*structured=*/false, /*implicit=*/true,
+          mlir::acc::DataClause::acc_update_device, addrOp.getType());
+  // The single acc.update operand goes in the last operand segment.
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
+  createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
+  modBuilder.setInsertionPointAfter(postDeallocOp);
+}
+
 template <typename EntryOp, typename ExitOp>
 static void genGlobalCtors(Fortran::lower::AbstractConverter &converter,
                            mlir::OpBuilder &modBuilder,
@@ -2422,6 +2512,11 @@
                 /*implicit=*/true);
             createRegisterFunc<EntryOp>(
                 modBuilder, builder, operandLocation, globalOp, clause);
+            // Deallocation functions are only generated when the clause has an
+            // exit operation distinct from its entry operation.
+            if constexpr (!std::is_same_v<EntryOp, ExitOp>)
+              createDeclareDeallocFunc<ExitOp>(
+                  modBuilder, builder, operandLocation, globalOp, clause);
           } else {
             createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp,
                                   mlir::acc::DeclareEnterOp, ExitOp>(
diff --git a/flang/test/Lower/OpenACC/acc-declare.f90 b/flang/test/Lower/OpenACC/acc-declare.f90
--- a/flang/test/Lower/OpenACC/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/acc-declare.f90
@@ -282,6 +282,23 @@
 ! CHECK: return
 ! CHECK: }
 
+! CHECK-LABEL: func.func private @_QMacc_declare_allocatable_testEdata1_acc_declare_update_desc_pre_dealloc() {
+! CHECK: %[[GLOBAL_ADDR:.*]] = fir.address_of(@_QMacc_declare_allocatable_testEdata1) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[LOAD:.*]] = fir.load %[[GLOBAL_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[LOAD]] {acc.declare = #acc.declare<dataClause = acc_create>} : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+! CHECK: %[[DEVPTR:.*]] = acc.getdeviceptr varPtr(%[[BOXADDR]] : !fir.heap<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>> {dataClause = #acc<data_clause acc_create>, name = "data1", structured = false}
+! CHECK: acc.declare_exit dataOperands(%[[DEVPTR]] : !fir.heap<!fir.array<?xi32>>)
+! CHECK: acc.delete accPtr(%[[DEVPTR]] : !fir.heap<!fir.array<?xi32>>) {dataClause = #acc<data_clause acc_create>, name = "data1", structured = false}
+! CHECK: return
+! CHECK: }
+
+! CHECK-LABEL: func.func private @_QMacc_declare_allocatable_testEdata1_acc_declare_update_desc_post_dealloc() {
+! CHECK: %[[GLOBAL_ADDR:.*]] = fir.address_of(@_QMacc_declare_allocatable_testEdata1) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[UPDATE:.*]] = acc.update_device varPtr(%[[GLOBAL_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {implicit = true, name = "data1_desc", structured = false}
+! CHECK: acc.update dataOperands(%[[UPDATE]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: return
+! CHECK: }
+
 ! CHECK-LABEL: acc.global_dtor @_QMacc_declare_allocatable_testEdata1_acc_dtor {
 ! CHECK: %[[GLOBAL_ADDR:.*]] = fir.address_of(@_QMacc_declare_allocatable_testEdata1) {acc.declare = #acc.declare<dataClause = acc_create>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
 ! CHECK: %[[DEVICEPTR:.*]] = acc.getdeviceptr varPtr(%[[GLOBAL_ADDR]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {dataClause = #acc<data_clause acc_create>, name = "data1", structured = false}