diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -101,6 +101,26 @@ } } +static llvm::SmallVector +genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value box, int rank) { + llvm::SmallVector bounds; + mlir::Type idxTy = builder.getIndexType(); + mlir::Type boundTy = builder.getType(); + assert(box.getType().isa() && "expect firbox or fir.class"); + for (int dim = 0; dim < rank; ++dim) { + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim); + auto dimInfo = + builder.create(loc, idxTy, idxTy, idxTy, box, d); + mlir::Value empty; + mlir::Value bound = builder.create( + loc, boundTy, dimInfo.getLowerBound(), empty, dimInfo.getExtent(), + dimInfo.getByteStride(), true, empty); + bounds.push_back(bound); + } + return bounds; +} + static llvm::SmallVector genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::AbstractConverter &converter, @@ -205,9 +225,8 @@ fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - auto createOpAndAddOperand = [&](Fortran::lower::SymbolRef sym, - llvm::StringRef name, - mlir::Location loc) -> Op { + auto getDataOperandBaseAddr = [&](Fortran::lower::SymbolRef sym, + mlir::Location loc) -> mlir::Value { mlir::Value symAddr = converter.getSymbolAddress(sym); // TODO: Might need revisiting to handle for non-shared clauses if (!symAddr) { @@ -219,14 +238,28 @@ if (!symAddr) llvm::report_fatal_error("could not retrieve symbol address"); - if (symAddr.getType().isa()) - TODO(loc, "data operand operation creation for box types"); - Op op = builder.create(loc, symAddr.getType(), symAddr); + if (auto box = symAddr.getType().dyn_cast()) + if (!box.getEleTy().isa()) + TODO(loc, "pointer, allocatable and derived type box"); + + return symAddr; + }; + + auto createOpAndAddOperand = [&](mlir::Value baseAddr, + llvm::StringRef name, + mlir::Location loc, + llvm::SmallVector &bounds) { + if 
(baseAddr.getType().isa()) + baseAddr = builder.create(loc, baseAddr); + + Op op = builder.create(loc, baseAddr.getType(), baseAddr); op.setNameAttr(builder.getStringAttr(name)); op.setStructured(structured); op.setDataClause(dataClause); + if (bounds.size() > 0) + op->insertOperands(1, bounds); op->setAttr(Op::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({1, 0, 0})); + builder.getDenseI32ArrayAttr({1, 0, static_cast(bounds.size())})); dataOperands.push_back(op.getAccPtr()); return op; }; @@ -259,12 +292,9 @@ arrayElement->subscripts, asFortran, name); } asFortran << ')'; - Op op = createOpAndAddOperand(*name.symbol, asFortran.str(), - operandLocation); - op->insertOperands(1, bounds); - op->setAttr(Op::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {1, 0, static_cast(bounds.size())})); + mlir::Value baseAddr = getDataOperandBaseAddr(*name.symbol, operandLocation); + createOpAndAddOperand(baseAddr, asFortran.str(), + operandLocation, bounds); } else if (Fortran::parser::Unwrap< Fortran::parser::StructureComponent>( designator)) { @@ -275,8 +305,12 @@ &designator.u)}) { const Fortran::parser::Name &name = Fortran::parser::GetLastName(*dataRef); - createOpAndAddOperand(*name.symbol, name.ToString(), - operandLocation); + mlir::Value baseAddr = getDataOperandBaseAddr(*name.symbol, operandLocation); + llvm::SmallVector bounds; + if (baseAddr.getType().isa()) + bounds = genBoundsOpsFromBox(builder, operandLocation, baseAddr, (*expr).Rank()); + createOpAndAddOperand(baseAddr, name.ToString(), + operandLocation, bounds); } else { // Unsupported llvm::report_fatal_error( "Unsupported type of OpenACC operand"); @@ -287,8 +321,10 @@ [&](const Fortran::parser::Name &name) { mlir::Location operandLocation = converter.genLocation(name.source); - createOpAndAddOperand(*name.symbol, name.ToString(), - operandLocation); + mlir::Value baseAddr = getDataOperandBaseAddr(*name.symbol, operandLocation); + llvm::SmallVector bounds; + 
createOpAndAddOperand(baseAddr, name.ToString(), + operandLocation, bounds); }}, accObject.u); } diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90 --- a/flang/test/Lower/OpenACC/acc-enter-data.f90 +++ b/flang/test/Lower/OpenACC/acc-enter-data.f90 @@ -219,3 +219,94 @@ !CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) end subroutine + +! Test lowering of assumed shape arrays. +subroutine acc_enter_data_assumed(a, n, m) + integer :: n, m + real :: a(:) + +!CHECK-LABEL: func.func @_QPacc_enter_data_assumed( +!CHECK-SAME: %[[A:.*]]: !fir.box> {fir.bindc_name = "a"}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}, %[[M:.*]]: !fir.ref {fir.bindc_name = "m"}) { + + !$acc enter data create(a) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[DIMS]]#0 : index) extent(%[[DIMS]]#1 : index) stride(%[[DIMS]]#2 : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(:)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS]]#1 : index) startIdx(%[[ONE]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(2:)) +!CHECK: %[[SIDX:.*]] = arith.constant 1 : 
index +!CHECK: %[[LB:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[EXT:.*]] = arith.subi %[[DIMS]]#1, %[[LB]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) startIdx(%[[SIDX]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(2:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(:4)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[UB:.*]] = arith.constant 3 : index +!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) startIdx(%[[ONE]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(:4)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(6:10)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[LB:.*]] = arith.constant 5 : index +!CHECK: %[[UB:.*]] = arith.constant 9 : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) startIdx(%[[ONE]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(6:10)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(n:)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref +!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[ONE]] : 
index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[EXT:.*]] = arith.subi %[[DIMS]]#1, %[[LB]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) startIdx(%[[ONE]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(n:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(:m)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref +!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[ONE]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) startIdx(%[[ONE]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(:m)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(n:m)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref +!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[ONE]] : index +!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref +!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[ONE]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) startIdx(%[[ONE]] : index) +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> 
{name = "a(n:m)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + +end subroutine