diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2305,6 +2305,60 @@
   createSimpleOp<Op>(firOpBuilder, currentLocation, operands, operandSegments);
 }
 
+/// Lower an OpenACC `set` directive to an `acc.set` operation created at
+/// \p currentLocation. The `if`, `default_async`, `device_num` and
+/// `device_type` clauses of \p accClauseList are lowered to operands; any
+/// other clause kind is ignored here. When `default_async` or `device_num`
+/// appears more than once, the value of the last occurrence is used.
+void genACCSetOp(Fortran::lower::AbstractConverter &converter,
+                 mlir::Location currentLocation,
+                 const Fortran::parser::AccClauseList &accClauseList) {
+  mlir::Value ifCond, deviceNum, defaultAsync;
+  llvm::SmallVector<mlir::Value> deviceTypeOperands;
+
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+
+  // Lower clauses values mapped to operands.
+  // Keep track of each group of operands separately as clauses can appear
+  // more than once.
+  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
+    mlir::Location clauseLocation = converter.genLocation(clause.source);
+    if (const auto *ifClause =
+            std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
+      genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
+    } else if (const auto *defaultAsyncClause =
+                   std::get_if<Fortran::parser::AccClause::DefaultAsync>(
+                       &clause.u)) {
+      defaultAsync = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(defaultAsyncClause->v), stmtCtx));
+    } else if (const auto *deviceNumClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceNum>(
+                       &clause.u)) {
+      deviceNum = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx));
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
+                          deviceTypeOperands, stmtCtx);
+    }
+  }
+
+  // Prepare the operand segment size attribute and the operands value range.
+  // NOTE(review): this order presumably mirrors the operand segment layout
+  // of the acc.set op definition; confirm before reordering.
+  llvm::SmallVector<mlir::Value> operands;
+  llvm::SmallVector<int32_t> operandSegments;
+  addOperands(operands, operandSegments, deviceTypeOperands);
+  addOperand(operands, operandSegments, defaultAsync);
+  addOperand(operands, operandSegments, deviceNum);
+  addOperand(operands, operandSegments, ifCond);
+
+  createSimpleOp<mlir::acc::SetOp>(firOpBuilder, currentLocation, operands,
+                                   operandSegments);
+}
+
 static void
 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                mlir::Location currentLocation,
@@ -2425,7 +2479,7 @@
     genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, currentLocation,
                                                 accClauseList);
   } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
-    TODO(currentLocation, "OpenACC set directive not lowered yet!");
+    genACCSetOp(converter, currentLocation, accClauseList);
   } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
     genACCUpdateOp(converter, currentLocation, semanticsContext, stmtCtx,
                    accClauseList);
diff --git a/flang/test/Lower/OpenACC/acc-set.f90 b/flang/test/Lower/OpenACC/acc-set.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-set.f90
@@ -0,0 +1,40 @@
+! This test checks lowering of OpenACC set directive.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+program test_acc_set
+  logical :: l
+
+!$acc set default_async(1)
+
+!$acc set default_async(1) if(l)
+
+!$acc set device_num(0)
+
+!$acc set device_type(*)
+
+!$acc set device_type(0)
+
+end
+
+! CHECK-LABEL: func.func @_QQmain()
+! CHECK: %[[L:.*]] = fir.alloca !fir.logical<4> {bindc_name = "l", uniq_name = "_QFEl"}
+
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: acc.set default_async(%[[C1]] : i32)
+
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: %[[LOAD_L:.*]] = fir.load %[[L]] : !fir.ref<!fir.logical<4>>
+! CHECK: %[[CONV_L:.*]] = fir.convert %[[LOAD_L]] : (!fir.logical<4>) -> i1
+! CHECK: acc.set default_async(%[[C1]] : i32) if(%[[CONV_L]])
+
+! CHECK: %[[C0:.*]] = arith.constant 0 : i32
+! CHECK: acc.set device_num(%[[C0]] : i32)
+
+! CHECK: %[[C_1:.*]] = arith.constant -1 : index
+! CHECK: acc.set device_type(%[[C_1]] : index)
+
+! CHECK: %[[C0:.*]] = arith.constant 0 : i32
+! CHECK: acc.set device_type(%[[C0]] : i32)
+
+