diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -131,7 +131,8 @@ Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx, const std::list &subscripts, - std::stringstream &asFortran, const Fortran::parser::Name &name) { + std::stringstream &asFortran, const Fortran::parser::Name &name, + mlir::Value baseAddr) { int dimension = 0; mlir::Type idxTy = builder.getIndexType(); mlir::Type boundTy = builder.getType(); @@ -146,17 +147,16 @@ mlir::Value lbound, ubound, extent; std::optional lval, uval; mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); - fir::ExtendedValue dataExv = - converter.getSymbolExtendedValue(*name.symbol); mlir::Value baseLb = fir::factory::readLowerBound(builder, loc, dataExv, dimension, one); bool defaultLb = baseLb == one; mlir::Value stride; bool strideInBytes = false; - if (fir::getBase(dataExv).getType().isa()) { + // NOTE(review): this hunk removes the `dataExv` declaration, yet the unchanged readLowerBound call in this hunk still reads it; confirm it is declared elsewhere in this function. + if (fir::unwrapRefType(baseAddr.getType()).isa()) { mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); auto dimInfo = builder.create(loc, idxTy, idxTy, idxTy, - fir::getBase(dataExv), d); + baseAddr, d); stride = dimInfo.getByteStride(); strideInBytes = true; } @@ -255,24 +255,33 @@ if (!symAddr) llvm::report_fatal_error("could not retrieve symbol address"); - mlir::Type symTy = symAddr.getType(); - if (auto refTy = symTy.dyn_cast()) - symTy = refTy.getEleTy(); - - if (auto boxTy = - fir::unwrapRefType(symAddr.getType()).dyn_cast()) - if (boxTy.getEleTy() - .isa()) - TODO(loc, "pointer, allocatable and derived type box"); + if (auto boxTy = fir::unwrapRefType(symAddr.getType()) + .dyn_cast()) { + if (boxTy.getEleTy().isa()) + TODO(loc, "derived type"); + // Load the box when symAddr is a reference to a descriptor, e.g. a + // 
`fir.ref<fir.box<fir.heap<...>>>` (allocatable) or `fir.ref<fir.box<fir.ptr<...>>>` (pointer) type.
+ if (symAddr.getType().isa()) + return builder.create(loc, symAddr); + } return symAddr; }; auto createOpAndAddOperand = [&](mlir::Value baseAddr, llvm::StringRef name, mlir::Location loc, llvm::SmallVector &bounds) { - if (baseAddr.getType().isa()) - baseAddr = builder.create(loc, baseAddr); + if (auto boxTy = baseAddr.getType().dyn_cast()) { + // Get the actual data address when the descriptor is an allocatable or + // a pointer. NOTE(review): fir.box_addr on such a box already returns the heap/ptr data address; giving it a reference result type and then loading from it may read the first data element as an address. Confirm against fir::BoxAddrOp semantics. + if (boxTy.getEleTy().isa()) { + mlir::Value boxAddr = builder.create( + loc, fir::ReferenceType::get(boxTy.getEleTy()), baseAddr); + baseAddr = builder.create(loc, boxAddr); + } else { // Get the address of the boxed value. + baseAddr = builder.create(loc, baseAddr); + } + } Op op = builder.create(loc, baseAddr.getType(), baseAddr); op.setNameAttr(builder.getStringAttr(name)); @@ -308,15 +317,15 @@ Fortran::parser::GetLastName(*dataRef); std::stringstream asFortran; asFortran << name.ToString(); + mlir::Value baseAddr = + getDataOperandBaseAddr(*name.symbol, operandLocation); if (!arrayElement->subscripts.empty()) { asFortran << '('; bounds = genBoundsOps(builder, operandLocation, converter, stmtCtx, arrayElement->subscripts, - asFortran, name); + asFortran, name, baseAddr); } asFortran << ')'; - mlir::Value baseAddr = - getDataOperandBaseAddr(*name.symbol, operandLocation); createOpAndAddOperand(baseAddr, asFortran.str(), operandLocation, bounds); } else if (Fortran::parser::Unwrap< @@ -332,7 +341,8 @@ mlir::Value baseAddr = getDataOperandBaseAddr(*name.symbol, operandLocation); llvm::SmallVector bounds; - if (baseAddr.getType().isa()) + if (fir::unwrapRefType(baseAddr.getType()) + .isa()) bounds = genBoundsOpsFromBox(builder, operandLocation, converter, *name.symbol, baseAddr, (*expr).Rank()); diff --git 
a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90 --- a/flang/test/Lower/OpenACC/acc-enter-data.f90 +++ b/flang/test/Lower/OpenACC/acc-enter-data.f90 @@ -41,10 +41,14 @@ !CHECK:
%[[CREATE_C:.*]] = acc.create varPtr(%[[C]] : !fir.ref>) -> !fir.ref> {dataClause = 8 : i64, name = "c", structured = false} !CHECK: acc.enter_data dataOperands(%[[CREATE_A]], %[[CREATE_B]], %[[CREATE_C]] : !fir.ref>, !fir.ref>, !fir.ref>){{$}} - !$acc enter data copyin(a) create(b) + !$acc enter data copyin(a) create(b) attach(d) !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[A]] : !fir.ref>) -> !fir.ref> {name = "a", structured = false} !CHECK: %[[CREATE_B:.*]] = acc.create varPtr(%[[B]] : !fir.ref>) -> !fir.ref> {name = "b", structured = false} -!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]], %[[CREATE_B]] : !fir.ref>, !fir.ref>){{$}} +!CHECK: %[[BOX_D:.*]] = fir.load %[[D]] : !fir.ref>> +!CHECK: %[[BOX_ADDR_D:.*]] = fir.box_addr %[[BOX_D]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[D_PTR:.*]] = fir.load %[[BOX_ADDR_D]] : !fir.ref> +!CHECK: %[[ATTACH_D:.*]] = acc.attach varPtr(%[[D_PTR]] : !fir.ptr) -> !fir.ptr {name = "d", structured = false} +!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]], %[[CREATE_B]], %[[ATTACH_D]] : !fir.ref>, !fir.ref>, !fir.ptr){{$}} !$acc enter data create(a) async !CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref>) -> !fir.ref> {name = "a", structured = false} @@ -348,3 +352,100 @@ end subroutine +subroutine acc_enter_data_allocatable() + real, allocatable :: a(:) + integer, allocatable :: i +! NOTE(review): expected IR below loads each data address through a fir.box_addr reference; confirm intended. +!CHECK-LABEL: func.func @_QPacc_enter_data_allocatable() { +!CHECK: %[[A:.*]] = fir.alloca !fir.box>> {bindc_name = "a", uniq_name = "_QFacc_enter_data_allocatableEa"} +!CHECK: %[[I:.*]] = fir.alloca !fir.box> {bindc_name = "i", uniq_name = "_QFacc_enter_data_allocatableEi"} + + !$acc enter data create(a) +!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0_0:.*]] = arith.constant 0 : index +!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0_1:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims 
%[[BOX_A_1]], %[[C0_1]] : (!fir.box>>, index) -> (index, index, index)
+!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0_0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS1]]#1 : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box>>) -> !fir.ref>> +!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref>> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap>) bounds(%[[BOUND]]) -> !fir.heap> {name = "a", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap>) + + !$acc enter data create(a(:)) +!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[BOX_A_2:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS2:.*]]:3 = fir.box_dims %[[BOX_A_2]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS2]]#1 : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box>>) -> !fir.ref>> +!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref>> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap>) bounds(%[[BOUND]]) -> !fir.heap> {name = "a(:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap>) + + !$acc enter data create(a(2:5)) +!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims
%[[BOX_A_1]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C2:.*]] = arith.constant 2 : index +!CHECK: %[[LB:.*]] = arith.subi %[[C2]], %[[DIMS0]]#0 : index +!CHECK: %[[C5:.*]] = arith.constant 5 : index +!CHECK: %[[UB:.*]] = arith.subi %[[C5]], %[[DIMS0]]#0 : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box>>) -> !fir.ref>> +!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref>> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap>) bounds(%[[BOUND]]) -> !fir.heap> {name = "a(2:5)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap>) + + !$acc enter data create(a(3:)) +!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C3:.*]] = arith.constant 3 : index +!CHECK: %[[LB:.*]] = arith.subi %[[C3]], %[[DIMS0]]#0 : index +!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS2:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[EXT:.*]] = arith.subi %[[DIMS2]]#1, %[[LB]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true} +!CHECK:
%[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box>>) -> !fir.ref>> +!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref>> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap>) bounds(%[[BOUND]]) -> !fir.heap> {name = "a(3:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap>) + + !$acc enter data create(a(:7)) +!CHECK: %[[BOX_A_0:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[BOX_A_1:.*]] = fir.load %[[A]] : !fir.ref>>> +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[BOX_A_1]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[BOX_A_0]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +!CHECK: %[[C7:.*]] = arith.constant 7 : index +!CHECK: %[[UB:.*]] = arith.subi %[[C7]], %[[DIMS0]]#0 : index +!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) stride(%[[DIMS1]]#2 : index) startIdx(%[[DIMS0]]#0 : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_A_0]] : (!fir.box>>) -> !fir.ref>> +!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref>> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap>) bounds(%[[BOUND]]) -> !fir.heap> {name = "a(:7)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap>) + + !$acc enter data create(i) +!CHECK: %[[BOX_I:.*]] = fir.load %[[I]] : !fir.ref>> +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX_I]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap) -> !fir.heap {name = "i", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap) + +end subroutine +