diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1089,8 +1089,14 @@
     TODO(toLocation(), "CompilerDirective lowering");
   }
 
-  void genFIR(const Fortran::parser::OpenACCConstruct &) {
-    TODO(toLocation(), "OpenACCConstruct lowering");
+  void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
+    // Nested evaluations are lowered at the insertion point left by
+    // genOpenACCConstruct; the caller's insertion point is restored afterwards.
+    mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    genOpenACCConstruct(*this, getEval(), acc);
+    for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+      genFIR(e);
+    builder->restoreInsertionPoint(insertPt);
   }
 
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -120,6 +120,63 @@
   return op;
 }
 
+/// Lower an `async` clause. When the clause has a value, the lowered value is
+/// stored in \p async; a clause without a value only sets \p addAsyncAttr.
+static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
+                           const Fortran::parser::AccClause::Async *asyncClause,
+                           mlir::Value &async, bool &addAsyncAttr,
+                           Fortran::lower::StatementContext &stmtCtx) {
+  const auto &asyncClauseValue = asyncClause->v;
+  if (asyncClauseValue) { // async has a value.
+    async = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
+  } else {
+    addAsyncAttr = true;
+  }
+}
+
+/// Lower an `if` clause: evaluate its condition and store it, converted to
+/// i1, in \p ifCond.
+static void genIfClause(Fortran::lower::AbstractConverter &converter,
+                        const Fortran::parser::AccClause::If *ifClause,
+                        mlir::Value &ifCond,
+                        Fortran::lower::StatementContext &stmtCtx) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Value cond = fir::getBase(converter.genExprValue(
+      *Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
+  ifCond = firOpBuilder.createConvert(converter.getCurrentLocation(),
+                                      firOpBuilder.getI1Type(), cond);
+}
+
+/// Lower a `wait` clause. Each queue expression is appended to \p operands and
+/// an optional devnum expression is stored in \p waitDevnum; a clause without
+/// a value only sets \p addWaitAttr.
+static void genWaitClause(Fortran::lower::AbstractConverter &converter,
+                          const Fortran::parser::AccClause::Wait *waitClause,
+                          SmallVectorImpl<mlir::Value> &operands,
+                          mlir::Value &waitDevnum, bool &addWaitAttr,
+                          Fortran::lower::StatementContext &stmtCtx) {
+  const auto &waitClauseValue = waitClause->v;
+  if (waitClauseValue) { // wait has a value.
+    const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+    const std::list<Fortran::parser::ScalarIntExpr> &waitList =
+        std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+      mlir::Value v = fir::getBase(
+          converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
+      operands.push_back(v);
+    }
+
+    const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
+        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    if (waitDevnumValue)
+      waitDevnum = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
+  } else {
+    addWaitAttr = true;
+  }
+}
+
 static void genACC(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
@@ -540,7 +597,7 @@
 genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
                   const Fortran::parser::AccClauseList &accClauseList) {
   mlir::Value ifCond, async, waitDevnum;
-  SmallVector<Value, 2> copyinOperands, createOperands, createZeroOperands,
-      attachOperands, waitOperands;
+  SmallVector<mlir::Value, 2> copyinOperands, createOperands,
+      createZeroOperands, attachOperands, waitOperands;
 
   // Async, wait and self clause have optional values but can be present with
@@ -549,50 +606,24 @@
   bool addAsyncAttr = false;
   bool addWaitAttr = false;
 
-  auto &firOpBuilder = converter.getFirOpBuilder();
-  auto currentLocation = converter.getCurrentLocation();
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
   Fortran::lower::StatementContext stmtCtx;
 
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
-  for (const auto &clause : accClauseList.v) {
+  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     if (const auto *ifClause =
             std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
-      mlir::Value cond = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
-      ifCond = firOpBuilder.createConvert(currentLocation,
-                                          firOpBuilder.getI1Type(), cond);
+      genIfClause(converter, ifClause, ifCond, stmtCtx);
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      const auto &asyncClauseValue = asyncClause->v;
-      if (asyncClauseValue) { // async has a value.
-        async = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
-      } else {
-        addAsyncAttr = true;
-      }
+      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      const auto &waitClauseValue = waitClause->v;
-      if (waitClauseValue) { // wait has a value.
-        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
-        const std::list<Fortran::parser::ScalarIntExpr> &waitList =
-            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
-          mlir::Value v = fir::getBase(converter.genExprValue(
-              *Fortran::semantics::GetExpr(value), stmtCtx));
-          waitOperands.push_back(v);
-        }
-
-        const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
-            std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-        if (waitDevnumValue)
-          waitDevnum = fir::getBase(converter.genExprValue(
-              *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
-      } else {
-        addWaitAttr = true;
-      }
+      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
+                    addWaitAttr, stmtCtx);
     } else if (const auto *copyinClause =
                    std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
       const Fortran::parser::AccObjectListWithModifier &listWithModifier =
@@ -627,7 +658,7 @@
   addOperands(operands, operandSegments, createZeroOperands);
   addOperands(operands, operandSegments, attachOperands);
 
-  auto enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
+  mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
       firOpBuilder, currentLocation, operands, operandSegments);
 
   if (addAsyncAttr)
diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-enter-data.f90
@@ -0,0 +1,69 @@
+! This test checks lowering of OpenACC enter data directive.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+subroutine acc_enter_data
+  integer :: async = 1
+  real, dimension(10, 10) :: a, b, c
+  real, pointer :: d
+  logical :: ifCondition = .TRUE.
+
+!CHECK: [[A:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
+!CHECK: [[B:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
+!CHECK: [[C:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ec"}
+!CHECK: [[D:%.*]] = fir.alloca !fir.box<!fir.ptr<f32>> {bindc_name = "d", uniq_name = "{{.*}}Ed"}
+
+  !$acc enter data create(a)
+!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc enter data create(a) if(.true.)
+!CHECK: [[IF1:%.*]] = arith.constant true
+!CHECK: acc.enter_data if([[IF1]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc enter data create(a) if(ifCondition)
+!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
+!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
+!CHECK: acc.enter_data if([[IF2]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc enter data create(a) create(b) create(c)
+!CHECK: acc.enter_data create([[A]], [[B]], [[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc enter data create(a) create(b) create(zero: c)
+!CHECK: acc.enter_data create([[A]], [[B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) create_zero([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+
+  !$acc enter data copyin(a) create(b) attach(d)
+!CHECK: acc.enter_data copyin([[A]] : !fir.ref<!fir.array<10x10xf32>>) create([[B]] : !fir.ref<!fir.array<10x10xf32>>) attach([[D]] : !fir.ref<!fir.box<!fir.ptr<f32>>>){{$}}
+
+  !$acc enter data create(a) async
+!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+
+  !$acc enter data create(a) wait
+!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+
+  !$acc enter data create(a) async wait
+!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+
+  !$acc enter data create(a) async(1)
+!CHECK: [[ASYNC1:%.*]] = arith.constant 1 : i32
+!CHECK: acc.enter_data async([[ASYNC1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc enter data create(a) async(async)
+!CHECK: [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
+!CHECK: acc.enter_data async([[ASYNC2]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc enter data create(a) wait(1)
+!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
+!CHECK: acc.enter_data wait([[WAIT1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc enter data create(a) wait(queues: 1, 2)
+!CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
+!CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
+!CHECK: acc.enter_data wait([[WAIT2]], [[WAIT3]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+  !$acc enter data create(a) wait(devnum: 1: queues: 1, 2)
+!CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
+!CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
+!CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
+!CHECK: acc.enter_data wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+
+end subroutine acc_enter_data