diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -132,10 +132,9 @@ static llvm::SmallVector genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::Name &name, mlir::Value baseAddr) { + fir::ExtendedValue dataExv, mlir::Value baseAddr) { mlir::Type idxTy = builder.getIndexType(); mlir::Type boundTy = builder.getType(); - fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*name.symbol); llvm::SmallVector bounds; if (dataExv.rank() == 0) @@ -160,13 +159,12 @@ Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext &stmtCtx, const std::list &subscripts, - std::stringstream &asFortran, const Fortran::parser::Name &name, + std::stringstream &asFortran, fir::ExtendedValue &dataExv, mlir::Value baseAddr) { int dimension = 0; mlir::Type idxTy = builder.getIndexType(); mlir::Type boundTy = builder.getType(); llvm::SmallVector bounds; - fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*name.symbol); for (const auto &subscript : subscripts) { if (const auto *triplet{ @@ -260,6 +258,45 @@ return bounds; }
+/// Walk back the IR from \p op to the fir.coordinate_of operation addressing a
+/// derived-type component and return the address of the enclosing structure.
+///
+/// fir.box_addr and fir.load operations are looked through. When the
+/// fir.coordinate_of has a single coordinate, its base reference is returned.
+/// When it has several and the second to last one is produced by a
+/// fir.field_index, a new fir.coordinate_of made of all the coordinates but
+/// the last one is created with \p builder at \p loc and returned.
+///
+/// \returns a null mlir::Value when no structure address can be recovered,
+/// which includes \p op being null.
+static mlir::Value findStructAddr(fir::FirOpBuilder &builder,
+                                  mlir::Location loc, mlir::Operation *op) {
+  if (auto coord = mlir::dyn_cast_or_null<fir::CoordinateOp>(op)) {
+    const std::size_t coordNb = coord.getCoor().size();
+    // Nothing to step back from without a coordinate; this also keeps the
+    // unsigned `coordNb - 2` below from wrapping around.
+    if (coordNb == 0)
+      return mlir::Value();
+    if (coordNb == 1)
+      return coord.getRef();
+    if (auto field = mlir::dyn_cast_or_null<fir::FieldIndexOp>(
+            coord.getCoor()[coordNb - 2].getDefiningOp())) {
+      // NOTE(review): the new address is typed as a reference to the type
+      // `field` is indexed on, while it addresses that field itself; confirm
+      // this is the intended result type.
+      mlir::Type baseTy = fir::ReferenceType::get(field.getOnType());
+      return builder.create<fir::CoordinateOp>(
+          loc, baseTy, coord.getRef(), coord.getCoor().take_front(coordNb - 1));
+    }
+    return mlir::Value();
+  }
+  if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(op))
+    return findStructAddr(builder, loc, boxAddrOp.getVal().getDefiningOp());
+  if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(op))
+    return findStructAddr(builder, loc, loadOp.getMemref().getDefiningOp());
+  return mlir::Value();
+}
+
 template static void genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, @@ -297,8 +318,8 @@ return symAddr; }; - auto createOpAndAddOperand = [&](mlir::Value baseAddr, llvm::StringRef name, - mlir::Location loc, + auto createOpAndAddOperand = [&](mlir::Value baseAddr, mlir::Value structAddr, + llvm::StringRef name, mlir::Location loc, llvm::SmallVector &bounds) { if (auto boxTy = baseAddr.getType().dyn_cast()) { // Get the actual data address when the descriptor is an allocatable or @@ -316,11 +337,15 @@ op.setNameAttr(builder.getStringAttr(name)); op.setStructured(structured); op.setDataClause(dataClause); + unsigned insPos = 1; + if (structAddr) + op->insertOperands(insPos++, structAddr); if (bounds.size() > 0) - op->insertOperands(1, bounds); - op->setAttr(Op::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {1, 0, static_cast(bounds.size())})); + op->insertOperands(insPos, bounds); + op->setAttr( + Op::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, structAddr ?
1 : 0, static_cast(bounds.size())})); dataOperands.push_back(op.getAccPtr()); return op; }; @@ -342,31 +367,73 @@ llvm::SmallVector bounds; const auto *dataRef = std::get_if(&designator.u); - const Fortran::parser::Name &name = - Fortran::parser::GetLastName(*dataRef); + mlir::Value addr; + mlir::Value baseAddr; std::stringstream asFortran; - asFortran << name.ToString(); - mlir::Value baseAddr = - getDataOperandBaseAddr(*name.symbol, operandLocation); + fir::ExtendedValue dataExv; + if (Fortran::parser::Unwrap< + Fortran::parser::StructureComponent>( + arrayElement->base)) { + auto exprBase = Fortran::semantics::AnalyzeExpr( + semanticsContext, arrayElement->base); + // 2.7.1. Any array or subarray in a data clause, including + // Fortran array pointers, must be a contiguous section of + // memory + dataExv = converter.genExprAddr(operandLocation, *exprBase, + stmtCtx); + addr = fir::getBase(dataExv); + baseAddr = findStructAddr(builder, operandLocation, + addr.getDefiningOp()); + asFortran << (*exprBase).AsFortran(); + } else { + const Fortran::parser::Name &name = + Fortran::parser::GetLastName(*dataRef); + addr = + getDataOperandBaseAddr(*name.symbol, operandLocation); + dataExv = converter.getSymbolExtendedValue(*name.symbol); + asFortran << name.ToString(); + } if (!arrayElement->subscripts.empty()) { asFortran << '('; bounds = genBoundsOps(builder, operandLocation, converter, stmtCtx, arrayElement->subscripts, - asFortran, name, baseAddr); + asFortran, dataExv, addr); } asFortran << ')'; - createOpAndAddOperand(baseAddr, asFortran.str(), + createOpAndAddOperand(addr, baseAddr, asFortran.str(), operandLocation, bounds); } else if (Fortran::parser::Unwrap< Fortran::parser::StructureComponent>( designator)) { - TODO(operandLocation, "OpenACC derived-type data operand"); + fir::ExtendedValue compExv = + converter.genExprAddr(operandLocation, *expr, stmtCtx); + mlir::Value compAddr = fir::getBase(compExv); + mlir::Value structAddr = findStructAddr( + builder, 
operandLocation, compAddr.getDefiningOp()); + llvm::SmallVector bounds; + if (fir::unwrapRefType(compAddr.getType()) + .isa()) + bounds = genBaseBoundsOps(builder, operandLocation, + converter, compExv, compAddr); + + // If the component is an allocatable or pointer the result of + // genExprAddr will be the result of a fir.box_addr operation. + // Retrieve the box so we handle it like other descriptor. + if (auto boxAddrOp = mlir::dyn_cast_or_null( + compAddr.getDefiningOp())) + compAddr = boxAddrOp.getVal(); + + createOpAndAddOperand(compAddr, structAddr, + (*expr).AsFortran(), operandLocation, + bounds); } else { // Scalar or full array. if (const auto *dataRef{std::get_if( &designator.u)}) { const Fortran::parser::Name &name = Fortran::parser::GetLastName(*dataRef); + fir::ExtendedValue dataExv = + converter.getSymbolExtendedValue(*name.symbol); mlir::Value baseAddr = getDataOperandBaseAddr(*name.symbol, operandLocation); llvm::SmallVector bounds; @@ -375,12 +442,13 @@ bounds = genBoundsOpsFromBox(builder, operandLocation, converter, *name.symbol, baseAddr, (*expr).Rank()); - if (fir::unwrapRefType(baseAddr.getType()) - .isa()) + else if (fir::unwrapRefType(baseAddr.getType()) + .isa()) bounds = genBaseBoundsOps(builder, operandLocation, - converter, name, baseAddr); - createOpAndAddOperand(baseAddr, name.ToString(), - operandLocation, bounds); + converter, dataExv, baseAddr); + createOpAndAddOperand(baseAddr, mlir::Value(), + name.ToString(), operandLocation, + bounds); } else { // Unsupported llvm::report_fatal_error( "Unsupported type of OpenACC operand"); @@ -394,8 +462,8 @@ mlir::Value baseAddr = getDataOperandBaseAddr(*name.symbol, operandLocation); llvm::SmallVector bounds; - createOpAndAddOperand(baseAddr, name.ToString(), operandLocation, - bounds); + createOpAndAddOperand(baseAddr, mlir::Value(), name.ToString(), + operandLocation, bounds); }}, accObject.u); } diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 
b/flang/test/Lower/OpenACC/acc-enter-data.f90 --- a/flang/test/Lower/OpenACC/acc-enter-data.f90 +++ b/flang/test/Lower/OpenACC/acc-enter-data.f90 @@ -538,3 +538,117 @@ end subroutine
+! Check the lowering of derived-type components used as data operands of
+! `acc enter data create`: the acc.create operation is expected to carry the
+! component in varPtr and the enclosing structure address in varPtrPtr.
+subroutine acc_enter_data_derived_type()
+  type :: dt
+    real :: data
+    real :: array(1:10)
+  end type
+
+  type :: t
+    type(dt) :: d
+  end type
+
+  type :: z
+    integer, allocatable :: data(:)
+  end type
+
+  type(dt) :: a
+  type(t) :: b
+  type(dt) :: aa(10)
+  type(z) :: c
+
+!CHECK-LABEL: func.func @_QPacc_enter_data_derived_type() {
+!CHECK: %[[A:.*]] = fir.alloca !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}> {bindc_name = "a", uniq_name = "_QFacc_enter_data_derived_typeEa"}
+!CHECK: %[[AA:.*]] = fir.alloca !fir.array<10x!fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>> {bindc_name = "aa", uniq_name = "_QFacc_enter_data_derived_typeEaa"}
+!CHECK: %[[B:.*]] = fir.alloca !fir.type<_QFacc_enter_data_derived_typeTt{d:!fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>}> {bindc_name = "b", uniq_name = "_QFacc_enter_data_derived_typeEb"}
+!CHECK: %[[C:.*]] = fir.alloca !fir.type<_QFacc_enter_data_derived_typeTz{data:!fir.box>>}> {bindc_name = "c", uniq_name = "_QFacc_enter_data_derived_typeEc"}
+
+  !$acc enter data create(a%data)
+!CHECK: %[[DATA_FIELD:.*]] = fir.field_index data, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[DATA_COORD:.*]] = fir.coordinate_of %[[A]], %[[DATA_FIELD]] : (!fir.ref}>>, !fir.field) -> !fir.ref
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[DATA_COORD]] : !fir.ref) varPtrPtr(%[[A]] : !fir.ref}>>) -> !fir.ref {name = "a%data", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref)
+
+  !$acc enter data create(b%d%data)
+!CHECK: %[[D_FIELD:.*]] = fir.field_index d, !fir.type<_QFacc_enter_data_derived_typeTt{d:!fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>}>
+!CHECK: %[[DATA_FIELD:.*]] = fir.field_index data, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[DATA_COORD:.*]] = fir.coordinate_of %[[B]], %[[D_FIELD]], %[[DATA_FIELD]] : (!fir.ref}>}>>, !fir.field, !fir.field) -> !fir.ref
+!CHECK: %[[D_COORD:.*]] = fir.coordinate_of %[[B]], %[[D_FIELD]] : (!fir.ref}>}>>, !fir.field) -> !fir.ref}>}>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[DATA_COORD]] : !fir.ref) varPtrPtr(%[[D_COORD]] : !fir.ref}>}>>) -> !fir.ref {name = "b%d%data", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref)
+
+  !$acc enter data create(a%array)
+!CHECK: %[[ARRAY_FIELD:.*]] = fir.field_index array, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[ARRAY_COORD:.*]] = fir.coordinate_of %[[A]], %[[ARRAY_FIELD]] : (!fir.ref}>>, !fir.field) -> !fir.ref>
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[C10]] : index) startIdx(%[[C1]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ARRAY_COORD]] : !fir.ref>) varPtrPtr(%[[A]] : !fir.ref}>>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a%array", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>)
+
+  !$acc enter data create(a%array(:))
+!CHECK: %[[ARRAY_FIELD:.*]] = fir.field_index array, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[ARRAY_COORD:.*]] = fir.coordinate_of %[[A]], %[[ARRAY_FIELD]] : (!fir.ref}>>, !fir.field) -> !fir.ref>
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[C10]] : index) startIdx(%[[C1]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ARRAY_COORD]] : !fir.ref>) varPtrPtr(%[[A]] : !fir.ref}>>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a%array(:)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>)
+
+  !$acc enter data create(a%array(1:5))
+!CHECK: %[[ARRAY_FIELD:.*]] = fir.field_index array, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[ARRAY_COORD:.*]] = fir.coordinate_of %[[A]], %[[ARRAY_FIELD]] : (!fir.ref}>>, !fir.field) -> !fir.ref>
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[C4:.*]] = arith.constant 4 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[C0]] : index) upperbound(%[[C4]] : index) startIdx(%[[C1]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ARRAY_COORD]] : !fir.ref>) varPtrPtr(%[[A]] : !fir.ref}>>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a%array(1:5)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>)
+
+  !$acc enter data create(a%array(:5))
+!CHECK: %[[ARRAY_FIELD:.*]] = fir.field_index array, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[ARRAY_COORD:.*]] = fir.coordinate_of %[[A]], %[[ARRAY_FIELD]] : (!fir.ref}>>, !fir.field) -> !fir.ref>
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[C4:.*]] = arith.constant 4 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[C4]] : index) startIdx(%[[C1]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ARRAY_COORD]] : !fir.ref>) varPtrPtr(%[[A]] : !fir.ref}>>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a%array(:5)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>)
+
+  !$acc enter data create(a%array(2:))
+!CHECK: %[[ARRAY_FIELD:.*]] = fir.field_index array, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[ARRAY_COORD:.*]] = fir.coordinate_of %[[A]], %[[ARRAY_FIELD]] : (!fir.ref}>>, !fir.field) -> !fir.ref>
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[STARTIDX:.*]] = arith.constant 1 : index
+!CHECK: %[[LB:.*]] = arith.constant 1 : index
+!CHECK: %[[EXT:.*]] = arith.subi %c10{{.*}}, %c1{{.*}} : index
+!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) startIdx(%[[STARTIDX]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ARRAY_COORD]] : !fir.ref>) varPtrPtr(%[[A]] : !fir.ref}>>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a%array(2:)", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>)
+
+  !$acc enter data create(b%d%array)
+!CHECK: %[[D_FIELD:.*]] = fir.field_index d, !fir.type<_QFacc_enter_data_derived_typeTt{d:!fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>}>
+!CHECK: %[[ARRAY_FIELD:.*]] = fir.field_index array, !fir.type<_QFacc_enter_data_derived_typeTdt{data:f32,array:!fir.array<10xf32>}>
+!CHECK: %[[ARRAY_COORD:.*]] = fir.coordinate_of %[[B]], %[[D_FIELD]], %[[ARRAY_FIELD]] : (!fir.ref}>}>>, !fir.field, !fir.field) -> !fir.ref>
+!CHECK: %[[C10:.*]] = arith.constant 10 : index
+!CHECK: %[[D_COORD:.*]] = fir.coordinate_of %[[B]], %[[D_FIELD]] : (!fir.ref}>}>>, !fir.field) -> !fir.ref}>}>>
+!CHECK: %[[C1:.*]] = arith.constant 1 : index
+!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[C10]] : index) startIdx(%[[C1]] : index)
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ARRAY_COORD]] : !fir.ref>) varPtrPtr(%[[D_COORD]] : !fir.ref}>}>>) bounds(%[[BOUND]]) -> !fir.ref> {name = "b%d%array", structured = false}
+
+  !$acc enter data create(c%data)
+!CHECK: %[[DATA_FIELD:.*]] = fir.field_index data, !fir.type<_QFacc_enter_data_derived_typeTz{data:!fir.box>>}>
+!CHECK: %[[DATA_COORD:.*]] = fir.coordinate_of %[[C]], %[[DATA_FIELD]] : (!fir.ref>>}>>, !fir.field) -> !fir.ref>>>
+!CHECK: %[[DATA_BOX:.*]] = fir.load %[[DATA_COORD]] : !fir.ref>>>
+!CHECK: %[[DIM0:.*]] = arith.constant 0 : index
+!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DATA_BOX]], %[[DIM0]] : (!fir.box>>, index) -> (index, index, index)
+!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS0]]#1 : index) startIdx(%[[DIMS0]]#0 : index)
+!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[DATA_BOX]] : (!fir.box>>) -> !fir.ref>>
+!CHECK: %[[ADDR:.*]] = fir.load %[[BOX_ADDR]] : !fir.ref>>
+!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[ADDR]] : !fir.heap>) varPtrPtr(%[[C]] : !fir.ref>>}>>) bounds(%[[BOUND]]) -> !fir.heap> {name = "c%data", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap>)
+
+
+end subroutine