diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -101,6 +101,7 @@
   }
 }
 
+/// Generate the acc.bounds operation from the descriptor information.
 static llvm::SmallVector<mlir::Value>
 genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
                     Fortran::lower::AbstractConverter &converter,
@@ -126,6 +127,42 @@
   return bounds;
 }
 
+/// Generate acc.bounds operations for a base array referenced without any
+/// subscripts. One bound is generated per dimension and covers the whole
+/// dimension: lowerbound 0, the full extent, stride 1, and the declared lower
+/// bound as startIdx. Scalars (rank 0) yield no bounds. The shape is read from
+/// the symbol's extended value; \p baseAddr is currently unused.
+static llvm::SmallVector<mlir::Value>
+genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
+                 Fortran::lower::AbstractConverter &converter,
+                 const Fortran::parser::Name &name, mlir::Value baseAddr) {
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Type boundTy = builder.getType<mlir::acc::DataBoundsType>();
+  fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*name.symbol);
+  llvm::SmallVector<mlir::Value> bounds;
+
+  if (dataExv.rank() == 0)
+    return bounds;
+
+  // The zero lower bound and the unit stride are shared by every dimension.
+  mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+  for (std::size_t dim = 0; dim < dataExv.rank(); ++dim) {
+    // The declared lower bound is only recorded as startIdx: the bound itself
+    // is zero-based and the extent already counts every element of the
+    // dimension, so the extent must not be adjusted by the lower bound.
+    mlir::Value lbound =
+        fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
+    mlir::Value extent = fir::factory::readExtent(builder, loc, dataExv, dim);
+    mlir::Value bound = builder.create<mlir::acc::DataBoundsOp>(
+        loc, boundTy, zero, mlir::Value(), extent, one, false, lbound);
+    bounds.push_back(bound);
+  }
+  return bounds;
+}
+
+/// Generate acc.bounds operations for an array section when subscripts are
+/// provided.
 static llvm::SmallVector<mlir::Value>
 genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
              Fortran::lower::AbstractConverter &converter,
@@ -346,6 +383,10 @@
           bounds = genBoundsOpsFromBox(builder, operandLocation, converter,
                                        *name.symbol, baseAddr,
                                        (*expr).Rank());
+        if (fir::unwrapRefType(baseAddr.getType())
+                .isa<fir::SequenceType>())
+          bounds = genBaseBoundsOps(builder, operandLocation, converter, name,
+                                    baseAddr);
         createOpAndAddOperand(baseAddr, name.ToString(), operandLocation,
                               bounds);
       } else { // Unsupported
diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90
--- a/flang/test/Lower/OpenACC/acc-enter-data.f90
+++ b/flang/test/Lower/OpenACC/acc-enter-data.f90
@@ -137,6 +137,7 @@
 !CHECK-LABEL: func.func @_QPacc_enter_data_dummy
 !CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "b"}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}, %[[M:.*]]: !fir.ref<i32> {fir.bindc_name = "m"}
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
 !CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref<i32>
 !CHECK: %[[N_I64:.*]] = fir.convert %[[LOAD_N]] : (i32) -> i64
 !CHECK: %[[N_IDX:.*]] = fir.convert %[[N_I64]] : (i64) -> index
@@ -150,6 +151,20 @@
 !CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[M_N_1]], %[[C0]] : index
 !CHECK: %[[EXT_B:.*]] = arith.select %[[CMP]], %[[M_N_1]], %[[C0]] : index
 
+  !$acc enter data create(a)
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[C0]] : index) extent(%[[C10]] : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref<!fir.array<10xf32>>)
+
+  !$acc enter data create(b)
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[C0]] : index) extent(%[[EXT_B]] : index) stride(%[[C1]] : index) startIdx(%[[N_IDX]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xf32>> {name = "b", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref<!fir.array<?xf32>>)
+
   !$acc enter data create(a(5:10))
 !CHECK: %[[LB1:.*]] = arith.constant 4 : index
 !CHECK: %[[UB1:.*]] = arith.constant 9 : index