Index: flang/include/flang/Optimizer/Dialect/FIROps.td =================================================================== --- flang/include/flang/Optimizer/Dialect/FIROps.td +++ flang/include/flang/Optimizer/Dialect/FIROps.td @@ -85,6 +85,9 @@ def AnyShapeLike : TypeConstraint, "any legal shape type">; def AnyShapeType : Type; +def AnyShapeOrShiftLike : TypeConstraint, "any legal shape or shift type">; +def AnyShapeOrShiftType : Type; def fir_SliceType : Type()">, "slice type">; def AnyEmboxLike : TypeConstraint { + + let summary = "Load an array as a value."; + + let description = [{ + Load an entire array as a single SSA value. + + ```fortran + real :: a(o:n,p:m) + ... + ... = ... a ... + ``` + + One can use `fir.array_load` to produce an ssa-value that captures an + immutable value of the entire array `a`, as in the Fortran array expression + shown above. Subsequent changes to the memory containing the array do not + alter its composite value. This operation lets one load an array as a + value while applying a runtime shape, shift, or slice to the memory + reference, and its semantics guarantee immutability. + + ```mlir + %s = fir.shape_shift %o, %n, %p, %m : (index, index, index, index) -> !fir.shapeshift<2> + // load the entire array 'a' + %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + // a fir.store here into array %a does not change %v + ``` + }]; + + let arguments = (ins + Arg:$memref, + Optional:$shape, + Optional:$slice, + Variadic:$lenParams + ); + + let results = (outs fir_SequenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`[`$slice^`]`)? (`typeparams` $lenParams^)?
attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ return ::verify(*this); }]; + + let extraClassDeclaration = [{ + std::vector getExtents(); + }]; +} + +def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> { + + let summary = "Fetch the value of an element of an array value"; + + let description = [{ + Fetch the value of an element in an array value. + + ```fortran + real :: a(n,m) + ... + ... a ... + ... a(r,s+1) ... + ``` + + One can use `fir.array_fetch` to fetch the (implied) value of `a(i,j)` in + an array expression as shown above. It can also be used to extract the + element `a(r,s+1)` in the second expression. + + ```mlir + %s = fir.shape %n, %m : (index, index) -> !fir.shape<2> + // load the entire array 'a' + %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array + // fetch the value of one of the array value's elements + %1 = fir.array_fetch %v, %i, %j : (!fir.array, index, index) -> f32 + ``` + + It is only possible to use `array_fetch` on an `array_load` result value. + }]; + + let arguments = (ins + fir_SequenceType:$sequence, + Variadic:$indices + ); + + let results = (outs AnyType:$element); + + let assemblyFormat = [{ + $sequence `,` $indices attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto arrTy = sequence().getType().cast(); + if (indices().size() != arrTy.getDimension()) + return emitOpError("number of indices != dimension of array"); + if (element().getType() != arrTy.getEleTy()) + return emitOpError("return type does not match array"); + if (!llvm::isa_and_nonnull<ArrayLoadOp>(sequence().getDefiningOp())) + return emitOpError("argument #0 must be result of fir.array_load"); + return mlir::success(); + }]; +} + +def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> { + + let summary = "Update the value of an element of an array value"; + + let description = [{ + Updates the value of an element in an array value.
A new array value is + returned where all element values of the input array are identical except + for the selected element which is the value passed in the update. + + ```fortran + real :: a(n,m) + ... + a = ... + ``` + + One can use `fir.array_update` to update the (implied) value of `a(i,j)` + in an array expression as shown above. + + ```mlir + %s = fir.shape %n, %m : (index, index) -> !fir.shape<2> + // load the entire array 'a' + %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array + // update the value of one of the array value's elements + // %r_{ij} = %f if (i,j) = (%i,%j), %v_{ij} otherwise + %r = fir.array_update %v, %f, %i, %j : (!fir.array, f32, index, index) -> !fir.array + fir.array_merge_store %v, %r to %a : !fir.ref> + ``` + + An array value update behaves as if a mapping function from the indices + to the new value has been added, replacing the previous mapping. These + mappings can be added to the ssa-value, but will not be materialized in + memory until the `fir.array_merge_store` is performed.
+ }]; + + let arguments = (ins + fir_SequenceType:$sequence, + AnyType:$merge, + Variadic:$indices + ); + + let results = (outs fir_SequenceType); + + let assemblyFormat = [{ + $sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto arrTy = sequence().getType().cast(); + if (merge().getType() != arrTy.getEleTy()) + return emitOpError("merged value does not have element type"); + if (indices().size() != arrTy.getDimension()) + return emitOpError("number of indices != dimension of array"); + return mlir::success(); + }]; +} + +def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [ + TypesMatchWith<"type of 'original' matches element type of 'memref'", + "memref", "original", + "fir::dyn_cast_ptrOrBoxEleTy($_self)">, + TypesMatchWith<"type of 'sequence' matches element type of 'memref'", + "memref", "sequence", + "fir::dyn_cast_ptrOrBoxEleTy($_self)">]> { + + let summary = "Store merged array value to memory."; + + let description = [{ + Store a merged array value to memory. + + ```fortran + real :: a(n,m) + ... + a = ... + ``` + + One can use `fir.array_merge_store` to merge/copy the value of `a` in an + array expression as shown above. + + ```mlir + %v = fir.array_load %a(%shape) : ... + %r = fir.array_update %v, %f, %i, %j : (!fir.array, f32, index, index) -> !fir.array + fir.array_merge_store %v, %r to %a : !fir.ref> + ``` + + This operation merges the original loaded array value, `%v`, with the + chained updates, `%r`, and stores the result to the array at address, `%a`.
+ }]; + + let arguments = (ins + fir_SequenceType:$original, + fir_SequenceType:$sequence, + Arg:$memref + ); + + let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)"; + + let verifier = [{ + if (!llvm::isa_and_nonnull<ArrayLoadOp>(original().getDefiningOp())) + return emitOpError("operand #0 must be result of a fir.array_load op"); + return mlir::success(); + }]; +} + +//===----------------------------------------------------------------------===// // Record and array type operations +//===----------------------------------------------------------------------===// + +def fir_ArrayCoorOp : fir_Op<"array_coor", + [NoSideEffect, AttrSizedOperandSegments]> { + + let summary = "Find the coordinate of an element of an array"; + + let description = [{ + Compute the location of an element in an array when the shape of the + array is only known at runtime. + + This operation is intended to capture all the runtime values needed to + compute the address of an array reference in a single high-level op. Given + the following Fortran input: + + ```fortran + real :: a(n,m) + ... + ... a(i,j) ... + ``` + + One can use `fir.array_coor` to determine the address of `a(i,j)`. + + ```mlir + %s = fir.shape %n, %m : (index, index) -> !fir.shape<2> + %1 = fir.array_coor %a(%s) %i, %j : (!fir.ref>, !fir.shape<2>, index, index) -> !fir.ref + ``` + }]; + + let arguments = (ins + AnyRefOrBox:$memref, + Optional:$shape, + Optional:$slice, + Variadic:$indices, + Variadic:$lenParams + ); + + let results = (outs fir_ReferenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)?
attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ return ::verify(*this); }]; +} def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { + let summary = "Finds the coordinate (location) of a value in memory"; let description = [{ @@ -1762,18 +2019,218 @@ } }]; - let builders = [ - OpBuilderDAG<(ins "StringRef":$fieldName, "Type":$recTy, - CArg<"ValueRange", "{}">:$operands), + let builders = [OpBuilderDAG<(ins "llvm::StringRef":$fieldName, + "mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands), [{ - $_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName)); + $_state.addAttribute(fieldAttrName(), + $_builder.getStringAttr(fieldName)); $_state.addAttribute(typeAttrName(), TypeAttr::get(recTy)); $_state.addOperands(operands); - }]>]; + }] + >]; let extraClassDeclaration = [{ static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } static constexpr llvm::StringRef typeAttrName() { return "on_type"; } + llvm::StringRef getFieldName() { return field_id(); } + }]; +} + +def fir_ShapeOp : fir_Op<"shape", [NoSideEffect]> { + + let summary = "generate an abstract shape vector of type `!fir.shape`"; + + let description = [{ + The arguments are an ordered list of integral type values that define the + runtime extent of each dimension of an array. The shape information is + given in the same row-to-column order as Fortran. This abstract shape value + must be applied to a reified object, so all shape information must be + specified. The extent must be nonnegative.
+ + ```mlir + %d = fir.shape %row_sz, %col_sz : (index, index) -> !fir.shape<2> + ``` + }]; + + let arguments = (ins Variadic:$extents); + + let results = (outs fir_ShapeType); + + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto size = extents().size(); + auto shapeTy = getType().dyn_cast(); + assert(shapeTy && "must be a shape type"); + if (shapeTy.getRank() != size) + return emitOpError("shape type rank mismatch"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + std::vector getExtents() { + return {extents().begin(), extents().end()}; + } + }]; +} + +def fir_ShapeShiftOp : fir_Op<"shape_shift", [NoSideEffect]> { + + let summary = [{ + generate an abstract shape and shift vector of type `!fir.shapeshift` + }]; + + let description = [{ + The arguments are an ordered list of integral type values that is a multiple + of 2 in length. Each such pair is defined as: the lower bound and the + extent for that dimension. The shifted shape information is given in the + same row-to-column order as Fortran. This abstract shifted shape value must + be applied to a reified object, so all shifted shape information must be + specified. The extent must be nonnegative.
+ + ```mlir + %d = fir.shape_shift %lo, %extent : (index, index) -> !fir.shapeshift<1> + ``` + }]; + + let arguments = (ins Variadic:$pairs); + + let results = (outs fir_ShapeShiftType); + + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto size = pairs().size(); + if (size < 2 || size > 16 * 2) + return emitOpError("incorrect number of args"); + if (size % 2 != 0) + return emitOpError("requires a multiple of 2 args"); + auto shapeTy = getType().dyn_cast(); + assert(shapeTy && "must be a shape shift type"); + if (shapeTy.getRank() * 2 != size) + return emitOpError("shape type rank mismatch"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + // Logically unzip the origins from the extent values. + std::vector getOrigins() { + std::vector result; + for (auto i : llvm::enumerate(pairs())) + if (!(i.index() & 1)) + result.push_back(i.value()); + return result; + } + + // Logically unzip the extents from the origin values. + std::vector getExtents() { + std::vector result; + for (auto i : llvm::enumerate(pairs())) + if (i.index() & 1) + result.push_back(i.value()); + return result; + } + }]; +} + +def fir_ShiftOp : fir_Op<"shift", [NoSideEffect]> { + + let summary = "generate an abstract shift vector of type `!fir.shift`"; + + let description = [{ + The arguments are an ordered list of integral type values that define the + runtime lower bound of each dimension of an array. The shift information is + given in the same row-to-column order as Fortran. This abstract shift value + must be applied to a reified object, so all shift information must be + specified.
+ + ```mlir + %d = fir.shift %row_lb, %col_lb : (index, index) -> !fir.shift<2> + ``` + }]; + + let arguments = (ins Variadic:$origins); + + let results = (outs fir_ShiftType); + + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto size = origins().size(); + auto shiftTy = getType().dyn_cast(); + assert(shiftTy && "must be a shift type"); + if (shiftTy.getRank() != size) + return emitOpError("shift type rank mismatch"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + std::vector getOrigins() { + return {origins().begin(), origins().end()}; + } + }]; +} + +def fir_SliceOp : fir_Op<"slice", [NoSideEffect, AttrSizedOperandSegments]> { + + let summary = "generate an abstract slice vector of type `!fir.slice`"; + + let description = [{ + The array slicing arguments are an ordered list of integral type values + that must be a multiple of 3 in length. Each such triple is defined as: + the lower bound, the upper bound, and the stride for that dimension, as in + Fortran syntax. Both bounds are inclusive. The array slice information is + given in the same row-to-column order as Fortran. This abstract slice value + must be applied to a reified object, so all slice information must be + specified. The extent must be nonnegative and the stride must not be zero. + + ```mlir + %d = fir.slice %lo, %hi, %step : (index, index, index) -> !fir.slice<1> + ``` + + To support generalized slicing of Fortran's dynamic derived types, a slice + op can be given a component path (narrowing from the product type of the + original array to the specific elemental type of the sliced projection).
+ + ```mlir + %fld = fir.field_index component, !fir.type + %d = fir.slice %lo, %hi, %step path %fld : (index, index, index, !fir.field) -> !fir.slice<1> + ``` + }]; + + let arguments = (ins + Variadic:$triples, + Variadic:$fields + ); + + let results = (outs fir_SliceType); + + let assemblyFormat = [{ + $triples (`path` $fields^)? attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto size = triples().size(); + if (size < 3 || size > 16 * 3) + return emitOpError("incorrect number of args for triple"); + if (size % 3 != 0) + return emitOpError("requires a multiple of 3 args"); + auto sliceTy = getType().dyn_cast(); + assert(sliceTy && "must be a slice type"); + if (sliceTy.getRank() * 3 != size) + return emitOpError("slice type rank mismatch"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + unsigned getOutRank() { return getOutputRank(triples()); } + static unsigned getOutputRank(mlir::ValueRange triples); }]; } Index: flang/include/flang/Optimizer/Dialect/FIRType.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIRType.h +++ flang/include/flang/Optimizer/Dialect/FIRType.h @@ -91,6 +91,10 @@ /// not a memory reference type, then returns a null `Type`. mlir::Type dyn_cast_ptrEleTy(mlir::Type t); +/// Extract the `Type` pointed to from a FIR memory reference or box type. If +/// `t` is not a memory reference or box type, then returns a null `Type`.
+mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t); + // Intrinsic types /// Model of a Fortran INTEGER intrinsic type, including the KIND type Index: flang/lib/Optimizer/Dialect/FIROps.cpp =================================================================== --- flang/lib/Optimizer/Dialect/FIROps.cpp +++ flang/lib/Optimizer/Dialect/FIROps.cpp @@ -5,6 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRAttr.h" @@ -115,6 +119,96 @@ return HeapType::get(intype); } +//===----------------------------------------------------------------------===// +// ArrayCoorOp +//===----------------------------------------------------------------------===// + +static mlir::LogicalResult verify(fir::ArrayCoorOp op) { + auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType()); + if (!eleTy) + return op.emitOpError("must be a reference or box type"); + auto arrTy = eleTy.dyn_cast(); + if (!arrTy) + return op.emitOpError("must be a reference to an array"); + auto arrDim = arrTy.getDimension(); + + if (auto shapeOp = op.shape()) { + auto shapeTy = shapeOp.getType(); + unsigned shapeTyRank = 0; + if (auto s = shapeTy.dyn_cast()) { + shapeTyRank = s.getRank(); + } else if (auto ss = shapeTy.dyn_cast()) { + shapeTyRank = ss.getRank(); + } else { + auto s = shapeTy.cast(); + shapeTyRank = s.getRank(); + if (!op.memref().getType().isa()) + return op.emitOpError("shift can only be provided with fir.box memref"); + } + if (arrDim && arrDim != shapeTyRank) + return op.emitOpError("rank of dimension mismatched"); + if (shapeTyRank != op.indices().size()) + return op.emitOpError("number of indices do not match dim rank"); + } + + if (auto sliceOp =
op.slice()) + if (auto sliceTy = sliceOp.getType().dyn_cast()) + if (sliceTy.getRank() != arrDim) + return op.emitOpError("rank of dimension in slice mismatched"); + + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// ArrayLoadOp +//===----------------------------------------------------------------------===// + +std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() { + if (auto sh = shape()) + if (auto *op = sh.getDefiningOp()) { + if (auto shOp = dyn_cast<fir::ShapeOp>(op)) + return shOp.getExtents(); + if (auto shShOp = dyn_cast<fir::ShapeShiftOp>(op)) + return shShOp.getExtents(); + } + // No shape, or a fir.shift (which carries only origins): no extents. + return {}; +} + +static mlir::LogicalResult verify(fir::ArrayLoadOp op) { + auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType()); + if (!eleTy) + return op.emitOpError("must be a reference or box type"); + auto arrTy = eleTy.dyn_cast(); + if (!arrTy) + return op.emitOpError("must be a reference to an array"); + auto arrDim = arrTy.getDimension(); + + if (auto shapeOp = op.shape()) { + auto shapeTy = shapeOp.getType(); + unsigned shapeTyRank = 0; + if (auto s = shapeTy.dyn_cast()) { + shapeTyRank = s.getRank(); + } else if (auto ss = shapeTy.dyn_cast()) { + shapeTyRank = ss.getRank(); + } else { + auto s = shapeTy.cast(); + shapeTyRank = s.getRank(); + if (!op.memref().getType().isa()) + return op.emitOpError("shift can only be provided with fir.box memref"); + } + if (arrDim && arrDim != shapeTyRank) + return op.emitOpError("rank of dimension mismatched"); + } + + if (auto sliceOp = op.slice()) + if (auto sliceTy = sliceOp.getType().dyn_cast()) + if (sliceTy.getRank() != arrDim) + return op.emitOpError("rank of dimension in slice mismatched"); + + return mlir::success(); +} + //===----------------------------------------------------------------------===// // BoxAddrOp //===----------------------------------------------------------------------===// Index: flang/lib/Optimizer/Dialect/FIRType.cpp =================================================================== ---
flang/lib/Optimizer/Dialect/FIRType.cpp +++ flang/lib/Optimizer/Dialect/FIRType.cpp @@ -725,6 +725,19 @@ .Default([](mlir::Type) { return mlir::Type{}; }); } +mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { + return llvm::TypeSwitch(t) + .Case( + [](auto p) { return p.getEleTy(); }) + .Case([](auto p) { + auto eleTy = p.getEleTy(); + if (auto ty = fir::dyn_cast_ptrEleTy(eleTy)) + return ty; + return eleTy; + }) + .Default([](mlir::Type) { return mlir::Type{}; }); +} + } // namespace fir // Len Index: flang/test/Fir/fir-ops.fir =================================================================== --- flang/test/Fir/fir-ops.fir +++ flang/test/Fir/fir-ops.fir @@ -621,5 +621,17 @@ // CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32> %arr2 = fir.zero_bits !fir.array<10xi32> + + // CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2> + // CHECK: [[AV1:%.*]] = fir.array_load [[ARR1]]([[SHAPE]]) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + // CHECK: [[FVAL:%.*]] = fir.array_fetch [[AV1]], [[I10]], [[J20]] : (!fir.array, index, index) -> f32 + // CHECK: [[AV2:%.*]] = fir.array_update [[AV1]], [[FVAL]], [[I10]], [[J20]] : (!fir.array, f32, index, index) -> !fir.array + // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.ref> + %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> + %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + %f = fir.array_fetch %av1, %i10, %j20 : (!fir.array, index, index) -> f32 + %av2 = fir.array_update %av1, %f, %i10, %j20 : (!fir.array, f32, index, index) -> !fir.array + fir.array_merge_store %av1, %av2 to %arr1 : !fir.ref> + return }