Add the fir.array_access operation.

fir.array_access yields a !fir.ref to one element of a !fir.array value (the
example in its description applies it to the result of fir.array_load).

  * FIROps.td   : op definition -- a sequence operand, variadic indices and
                  optional typeparams; a single reference result.
  * FIROps.cpp  : verifier. It rejects the op when
                  (1) there are fewer indices than the array's dimension,
                  (2) the index count equals the dimension but the result is
                      not a reference to the array's element type, or
                  (3) validArraySubobject() yields no type for the operands,
                      or the result is not a reference to that type.
                  Check (1) does not reject MORE indices than dimensions; the
                  third invalid.fir case relies on this to index through a
                  !fir.field into a component and is caught by check (3).
  * fir-ops.fir : round-trip test.
  * invalid.fir : one negative test per verifier check.

NOTE(review): text between angle brackets appears to have been stripped from
this copy of the patch: `.cast()` has no template argument,
`Variadic:$indices` / `Variadic:$typeparams` have no type constraint, and
types such as `!fir.ref>` are unbalanced. Restore them from the original
change before applying; as written, neither the C++ nor the tests are valid.

NOTE(review): check (1) fires only when there are fewer indices than the
dimension, yet its message says "!=". It is left unchanged here because the
invalid.fir test pins that exact text.

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -1533,6 +1533,60 @@
   let verifier = "return ::verify(*this);";
 }
 
+def fir_ArrayAccessOp : fir_Op<"array_access", [AttrSizedOperandSegments,
+    NoSideEffect]> {
+  let summary = "Fetch the reference of an element of an array value";
+
+  let description = [{
+    The `array_access` operation provides a reference to a single element from
+    an array value. This is *not* a view in the immutable array, otherwise it
+    couldn't be stored to. It can be seen as a logical copy of the element and
+    its position in the array. This reference can be written to and modified
+    without changing the original array.
+
+    The `array_access` operation is used to fetch the memory reference of an
+    element in an array value.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      ... a ...
+      ... a(r,s+1) ...
+    ```
+
+    One can use `fir.array_access` to recover the implied memory reference to
+    the element `a(i,j)` in an array expression `a` as shown above. It can also
+    be used to recover the reference to the element `a(r,s+1)` in the second
+    expression.
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array
+      // fetch a reference to one of the array value's elements
+      %1 = fir.array_access %v, %i, %j : (!fir.array, index, index) -> !fir.ref
+    ```
+
+    More information about `array_access` and other array operations can be
+    found in flang/docs/FIRArrayOperations.md.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$sequence,
+    Variadic:$indices,
+    Variadic:$typeparams
+  );
+
+  let results = (outs fir_ReferenceType:$element);
+
+  let assemblyFormat = [{
+    $sequence `,` $indices (`typeparams` $typeparams^)? attr-dict `:`
+      functional-type(operands, results)
+  }];
+
+  let verifier = "return ::verify(*this);";
+}
+
 def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store",
     [AttrSizedOperandSegments]> {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -496,6 +496,24 @@
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ArrayAccessOp
+//===----------------------------------------------------------------------===//
+
+static mlir::LogicalResult verify(fir::ArrayAccessOp op) {
+  auto arrTy = op.sequence().getType().cast();
+  std::size_t indSize = op.indices().size();
+  if (indSize < arrTy.getDimension())
+    return op.emitOpError("number of indices != dimension of array");
+  if (indSize == arrTy.getDimension() &&
+      op.element().getType() != fir::ReferenceType::get(arrTy.getEleTy()))
+    return op.emitOpError("return type does not match array");
+  mlir::Type ty = validArraySubobject(op);
+  if (!ty || fir::ReferenceType::get(ty) != op.getType())
+    return op.emitOpError("return type and/or indices do not type check");
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // ArrayUpdateOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -747,3 +747,14 @@
   %1 = fir.load %0 : !fir.llvm_ptr>>>
   return %1 : !fir.ref>>
 }
+
+func @array_access_ops(%a : !fir.ref>) {
+  %c1 = arith.constant 1 : index
+  %n = arith.constant 0 : index
+  %m = arith.constant 50 : index
+  %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+  %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array
+  %p = fir.array_access %v, %c1, %c1 : (!fir.array, index, index) -> !fir.ref
+  // CHECK: %{{.*}} = fir.array_access %{{.*}}, %{{.*}}, %{{.*}} : (!fir.array, index, index) -> !fir.ref
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -694,3 +694,55 @@
   fir.array_merge_store %av1, %av2 to %arr1[%slice] : !fir.array, !fir.array, !fir.ref>, !fir.slice<1>
   return
 }
+
+// -----
+
+func @array_access(%a : !fir.ref>) {
+  %c1 = arith.constant 1 : index
+  %n = arith.constant 0 : index
+  %m = arith.constant 50 : index
+  %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+  %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array
+  // expected-error@+1 {{'fir.array_access' op number of indices != dimension of array}}
+  %p = fir.array_access %v, %c1 : (!fir.array, index) -> !fir.ref
+  return
+}
+
+// -----
+
+func @array_access(%a : !fir.ref>) {
+  %c1 = arith.constant 1 : index
+  %n = arith.constant 0 : index
+  %m = arith.constant 50 : index
+  %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+  %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array
+  // expected-error@+1 {{'fir.array_access' op return type does not match array}}
+  %p = fir.array_access %v, %c1, %c1 : (!fir.array, index, index) -> !fir.ref
+  return
+}
+
+// -----
+
+func @foo(%arg0: !fir.ref}>>>) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c9 = arith.constant 9 : index
+  %c19 = arith.constant 19 : index
+  %c30 = arith.constant 30 : index
+  %0 = fir.shape %c30 : (index) -> !fir.shape<1>
+  %1 = fir.array_load %arg0(%0) : (!fir.ref}>>>, !fir.shape<1>) -> !fir.array<30x!fir.type}>>
+  %2 = fir.do_loop %arg1 = %c1 to %c9 step %c1 unordered iter_args(%arg2 = %1) -> (!fir.array<30x!fir.type}>>) {
+    %3 = fir.field_index c, !fir.type}>
+    %4 = fir.do_loop %arg3 = %c0 to %c19 step %c1 unordered iter_args(%arg4 = %arg2) -> (!fir.array<30x!fir.type}>>) {
+      // expected-error@+1 {{'fir.array_access' op return type and/or indices do not type check}}
+      %5 = fir.array_access %1, %arg1, %3, %arg3 : (!fir.array<30x!fir.type}>>, index, !fir.field, index) -> !fir.ref
+      %6 = fir.call @ifoo(%5) : (!fir.ref) -> i32
+      %7 = fir.array_update %arg4, %6, %arg1, %3, %arg3 : (!fir.array<30x!fir.type}>>, i32, index, !fir.field, index) -> !fir.array<30x!fir.type}>>
+      fir.result %7 : !fir.array<30x!fir.type}>>
+    }
+    fir.result %4 : !fir.array<30x!fir.type}>>
+  }
+  fir.array_merge_store %1, %2 to %arg0 : !fir.array<30x!fir.type}>>, !fir.array<30x!fir.type}>>, !fir.ref}>>>
+  return
+}
+func private @ifoo(!fir.ref) -> i32