diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -101,6 +101,34 @@ } } +/// Generate one acc.bounds operation per dimension [0, rank) of the +/// descriptor \p box, using the extent and byte stride read from the box and +/// the lower bound of \p sym (defaulting to 1) as the start index. +static llvm::SmallVector +genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc, + Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymbolRef sym, mlir::Value box, int rank) { + llvm::SmallVector bounds; + mlir::Type idxTy = builder.getIndexType(); + fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym); + mlir::Type boundTy = builder.getType(); + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); + assert(box.getType().isa() && "expect firbox or fir.class"); + for (int dim = 0; dim < rank; ++dim) { + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim); + mlir::Value baseLb = + fir::factory::readLowerBound(builder, loc, dataExv, dim, one); + auto dimInfo = + builder.create(loc, idxTy, idxTy, idxTy, box, d); + mlir::Value empty; + mlir::Value bound = builder.create( + loc, boundTy, empty, empty, dimInfo.getExtent(), + dimInfo.getByteStride(), true, baseLb); + bounds.push_back(bound); + } + return bounds; +} + static llvm::SmallVector genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::AbstractConverter &converter, @@ -126,6 +154,16 @@ mlir::Value baseLb = fir::factory::readLowerBound(builder, loc, dataExv, dimension, one); bool defaultLb = baseLb == one; + // For boxed data, forward the runtime byte stride of this dimension. + mlir::Value stride; + bool strideInBytes = false; + if (fir::getBase(dataExv).getType().isa()) { + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); + auto dimInfo = builder.create(loc, idxTy, idxTy, idxTy, + fir::getBase(dataExv), d); + stride = dimInfo.getByteStride(); + strideInBytes = true; + } const auto &lower{std::get<0>(triplet->t)}; if (lower) { @@ -188,9 +226,8 @@ if (lbound) extent = builder.create(loc, extent, lbound); } - mlir::Value empty; mlir::Value bound = builder.create( - loc, boundTy, lbound, ubound, extent, empty, false, 
baseLb); + loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb); bounds.push_back(bound); ++dimension; } @@ -209,9 +246,9 @@ fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - auto createOpAndAddOperand = [&](Fortran::lower::SymbolRef sym, - llvm::StringRef name, - mlir::Location loc) -> Op { + // Look up the address of sym to use as the base address of a data operand. + auto getDataOperandBaseAddr = [&](Fortran::lower::SymbolRef sym, + mlir::Location loc) -> mlir::Value { mlir::Value symAddr = converter.getSymbolAddress(sym); // TODO: Might need revisiting to handle for non-shared clauses if (!symAddr) { @@ -223,14 +260,31 @@ if (!symAddr) llvm::report_fatal_error("could not retrieve symbol address"); - if (symAddr.getType().isa()) - TODO(loc, "data operand operation creation for box types"); - Op op = builder.create(loc, symAddr.getType(), symAddr); + if (auto boxTy = + fir::unwrapRefType(symAddr.getType()).dyn_cast()) + if (!boxTy.getEleTy().isa()) + TODO(loc, "pointer, allocatable and derived type box"); + + return symAddr; + }; + + // Create the data clause operation on baseAddr (taking the box address for + // boxed values), attach bounds to it, and record it in dataOperands. + auto createOpAndAddOperand = [&](mlir::Value baseAddr, llvm::StringRef name, + mlir::Location loc, + llvm::SmallVector &bounds) { + if (baseAddr.getType().isa()) + baseAddr = builder.create(loc, baseAddr); + + Op op = builder.create(loc, baseAddr.getType(), baseAddr); op.setNameAttr(builder.getStringAttr(name)); op.setStructured(structured); op.setDataClause(dataClause); + if (!bounds.empty()) + op->insertOperands(1, bounds); op->setAttr(Op::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({1, 0, 0})); + builder.getDenseI32ArrayAttr( + {1, 0, static_cast(bounds.size())})); dataOperands.push_back(op.getAccPtr()); return op; }; @@ -263,12 +317,10 @@ asFortran, name); } asFortran << ')'; - Op op = createOpAndAddOperand(*name.symbol, asFortran.str(), - operandLocation); - op->insertOperands(1, bounds); - op->setAttr(Op::getOperandSegmentSizeAttr(), - 
builder.getDenseI32ArrayAttr( - {1, 0, static_cast(bounds.size())})); + mlir::Value baseAddr = + getDataOperandBaseAddr(*name.symbol, operandLocation); + createOpAndAddOperand(baseAddr, asFortran.str(), + operandLocation, bounds); } else if (Fortran::parser::Unwrap< Fortran::parser::StructureComponent>( designator)) { @@ -279,8 +331,15 @@ &designator.u)}) { const Fortran::parser::Name &name = Fortran::parser::GetLastName(*dataRef); - createOpAndAddOperand(*name.symbol, name.ToString(), - operandLocation); + mlir::Value baseAddr = + getDataOperandBaseAddr(*name.symbol, operandLocation); + llvm::SmallVector bounds; + if (baseAddr.getType().isa()) + bounds = genBoundsOpsFromBox(builder, operandLocation, + converter, *name.symbol, + baseAddr, (*expr).Rank()); + createOpAndAddOperand(baseAddr, name.ToString(), + operandLocation, bounds); } else { // Unsupported llvm::report_fatal_error( "Unsupported type of OpenACC operand"); @@ -291,8 +350,11 @@ [&](const Fortran::parser::Name &name) { mlir::Location operandLocation = converter.genLocation(name.source); - createOpAndAddOperand(*name.symbol, name.ToString(), - operandLocation); + mlir::Value baseAddr = + getDataOperandBaseAddr(*name.symbol, operandLocation); + llvm::SmallVector bounds; + createOpAndAddOperand(baseAddr, name.ToString(), operandLocation, + bounds); }}, accObject.u); } diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90 --- a/flang/test/Lower/OpenACC/acc-enter-data.f90 +++ b/flang/test/Lower/OpenACC/acc-enter-data.f90 @@ -219,3 +219,133 @@ !CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) end subroutine + +! Test lowering of assumed-shape arrays. 
+subroutine acc_enter_data_assumed(a, b, n, m) ! enter data create on whole and sectioned assumed-shape arrays + integer :: n, m ! runtime section bounds used below + real :: a(:) ! assumed-shape, lower bound 1 + real :: b(10:) ! assumed-shape, lower bound 10 + +!CHECK-LABEL: func.func @_QPacc_enter_data_assumed( +!CHECK-SAME: %[[A:.*]]: !fir.box> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.box> {fir.bindc_name = "b"}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}, %[[M:.*]]: !fir.ref {fir.bindc_name = "m"}) { + +!CHECK: %[[LB_C10:.*]] = arith.constant 10 : i64 +!CHECK: %[[LB_C10_IDX:.*]] = fir.convert %[[LB_C10]] : (i64) -> index + + !$acc enter data create(a) +!CHECK: %[[C1:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS]]#1 : index) stride(%[[DIMS]]#2 : index) startIdx(%[[C1]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(:)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS1]]#1 : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(2:)) +!CHECK: %[[SIDX:.*]] = arith.constant 1 : index 
+!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[LB:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[EXT:.*]] = arith.subi %[[DIMS1]]#1, %[[LB]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[SIDX]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(2:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(:4)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[UB:.*]] = arith.constant 3 : index +!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(:4)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(6:10)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[LB:.*]] = arith.constant 5 : index +!CHECK: %[[UB:.*]] = arith.constant 9 : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) stride(%[[DIMS0]]#2 : 
index) startIdx(%[[ONE]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(6:10)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(n:)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref +!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[ONE]] : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[EXT:.*]] = arith.subi %[[DIMS]]#1, %[[LB]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) extent(%[[EXT]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(n:)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(:m)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref +!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[ONE]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} +!CHECK: 
%[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(:m)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(a(n:m)) +!CHECK: %[[ONE:.*]] = arith.constant 1 : index +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[A]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[LOAD_N:.*]] = fir.load %[[N]] : !fir.ref +!CHECK: %[[CONVERT_N:.*]] = fir.convert %[[LOAD_N]] : (i32) -> index +!CHECK: %[[LB:.*]] = arith.subi %[[CONVERT_N]], %[[ONE]] : index +!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref +!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[ONE]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[ONE]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[A]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a(n:m)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(b(:m)) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[B]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[LOAD_M:.*]] = fir.load %[[M]] : !fir.ref +!CHECK: %[[CONVERT_M:.*]] = fir.convert %[[LOAD_M]] : (i32) -> index +!CHECK: %[[UB:.*]] = arith.subi %[[CONVERT_M]], %[[LB_C10_IDX]] : index +!CHECK: %[[BOUND:.*]] = acc.bounds upperbound(%[[UB]] : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[LB_C10_IDX]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[B]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : 
!fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "b(:m)", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + + !$acc enter data create(b) +!CHECK: %[[C0:.*]] = arith.constant 0 : index +!CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[B]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +!CHECK: %[[BOUND:.*]] = acc.bounds extent(%[[DIMS0]]#1 : index) stride(%[[DIMS0]]#2 : index) startIdx(%[[LB_C10_IDX]] : index) {strideInBytes = true} +!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[B]] : (!fir.box>) -> !fir.ref> +!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[BOX_ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "b", structured = false} +!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref>) + +end subroutine +