[flang][fir] array value ops: add type parameters, slices, and C++ verifiers
(free text before the first `diff --git` header is not part of any hunk)

* FIROps.td: fir.array_fetch and fir.array_update become
  AttrSizedOperandSegments ops with a trailing variadic `typeparams` operand
  list, printed after the indices as `typeparams %p...`. fir.array_merge_store
  drops its two TypesMatchWith constraints, becomes AttrSizedOperandSegments,
  and gains an optional `[%slice]` operand plus `typeparams`; its assembly
  format now prints `type(operands)` instead of only the memref type. All
  three ops replace their inline TableGen verifier bodies with
  `return ::verify(*this);`, implemented in FIROps.cpp.
* FIRType.h: forward-declares mlir::ValueRange; adds fir::isa_char (CHARACTER
  type, length not checked), fir::unwrapSequenceType, and the declaration of
  fir::applyPathToType.
* Support/Utils.h: new header (see the new-file hunk that starts below).

NOTE(review): text between `<` and `>` (template arguments such as
`Variadic<...>`, `isa<...>`, `dyn_cast<...>()` and FIR type parameters such as
`!fir.array<...>`) appears to have been stripped from this copy, and line
breaks were collapsed; restore both from the original commit before applying.

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -1658,7 +1658,8 @@ }]; } -def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> { +def fir_ArrayFetchOp : fir_Op<"array_fetch", [AttrSizedOperandSegments, + NoSideEffect]> { let summary = "Fetch the value of an element of an array value"; @@ -1689,28 +1690,22 @@ let arguments = (ins fir_SequenceType:$sequence, - Variadic:$indices + Variadic:$indices, + Variadic:$typeparams ); let results = (outs AnyType:$element); let assemblyFormat = [{ - $sequence `,` $indices attr-dict `:` functional-type(operands, results) + $sequence `,` $indices (`typeparams` $typeparams^)? attr-dict `:` + functional-type(operands, results) }]; - let verifier = [{ - auto arrTy = sequence().getType().cast(); - if (indices().size() != arrTy.getDimension()) - return emitOpError("number of indices != dimension of array"); - if (element().getType() != arrTy.getEleTy()) - return emitOpError("return type does not match array"); - if (!isa(sequence().getDefiningOp())) - return emitOpError("argument #0 must be result of fir.array_load"); - return mlir::success(); - }]; + let verifier = "return ::verify(*this);"; } -def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> { +def fir_ArrayUpdateOp : fir_Op<"array_update", [AttrSizedOperandSegments, + NoSideEffect]> { let summary = "Update the value of an element of an array value"; @@ -1747,32 +1742,22 @@ let arguments = (ins fir_SequenceType:$sequence, AnyType:$merge, - Variadic:$indices + Variadic:$indices, + Variadic:$typeparams ); let results = (outs fir_SequenceType); let assemblyFormat = [{ - $sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results) + $sequence `,` $merge `,` $indices (`typeparams` $typeparams^)? 
attr-dict + `:` functional-type(operands, results) }]; - let verifier = [{ - auto arrTy = sequence().getType().cast(); - if (merge().getType() != arrTy.getEleTy()) - return emitOpError("merged value does not have element type"); - if (indices().size() != arrTy.getDimension()) - return emitOpError("number of indices != dimension of array"); - return mlir::success(); - }]; + let verifier = "return ::verify(*this);"; } -def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [ - TypesMatchWith<"type of 'original' matches element type of 'memref'", - "memref", "original", - "fir::dyn_cast_ptrOrBoxEleTy($_self)">, - TypesMatchWith<"type of 'sequence' matches element type of 'memref'", - "memref", "sequence", - "fir::dyn_cast_ptrOrBoxEleTy($_self)">]> { +def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", + [AttrSizedOperandSegments]> { let summary = "Store merged array value to memory."; @@ -1801,16 +1786,17 @@ let arguments = (ins fir_SequenceType:$original, fir_SequenceType:$sequence, - Arg:$memref + Arg:$memref, + Optional:$slice, + Variadic:$typeparams ); - let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)"; - - let verifier = [{ - if (!isa(original().getDefiningOp())) - return emitOpError("operand #0 must be result of a fir.array_load op"); - return mlir::success(); + let assemblyFormat = [{ + $original `,` $sequence `to` $memref (`[` $slice^ `]`)? (`typeparams` + $typeparams^)? 
attr-dict `:` type(operands) }]; + + let verifier = "return ::verify(*this);"; } //===----------------------------------------------------------------------===// diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -33,6 +33,7 @@ class DialectAsmPrinter; class ComplexType; class FloatType; +class ValueRange; } // namespace mlir namespace fir { @@ -122,6 +123,9 @@ return t.isa() || t.isa(); } +/// Is `t` a CHARACTER type? Does not check the length. +inline bool isa_char(mlir::Type t) { return t.isa(); } + /// Is `t` a CHARACTER type with a LEN other than 1? inline bool isa_char_string(mlir::Type t) { if (auto ct = t.dyn_cast_or_null()) @@ -134,6 +138,13 @@ /// of unknown rank or type. bool isa_unknown_size_box(mlir::Type t); +/// If `t` is a SequenceType return its element type, otherwise return `t`. +inline mlir::Type unwrapSequenceType(mlir::Type t) { + if (auto seqTy = t.dyn_cast()) + return seqTy.getEleTy(); + return t; +} + #ifndef NDEBUG // !fir.ptr and !fir.heap where X is !fir.ptr, !fir.heap, or !fir.ref // is undefined and disallowed. @@ -142,6 +153,11 @@ } #endif +/// Apply the components specified by `path` to `rootTy` to determine the type +/// of the resulting component element. `rootTy` should be an aggregate type. +/// Returns null on error. +mlir::Type applyPathToType(mlir::Type rootTy, mlir::ValueRange path); + } // namespace fir #endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Support/Utils.h @@ -0,0 +1,26 @@ +//===-- Optimizer/Support/Utils.h -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
+#define FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace fir {
+/// Return the value of a ConstantOp whose value attribute is an IntegerAttr.
+inline std::int64_t toInt(mlir::ConstantOp cop) {
+  return cop.getValue().cast<mlir::IntegerAttr>().getValue().getSExtValue();
+}
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -14,6 +14,7 @@
 #include "flang/Optimizer/Dialect/FIRAttr.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/Utils.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -153,6 +154,22 @@
 // ArrayLoadOp
 //===----------------------------------------------------------------------===//
 
+/// If `t` is a !fir.ref to a CHARACTER type, a derived type (per
+/// fir::isa_derived), or a !fir.array, return the referenced type; otherwise
+/// return `t` unchanged.
+static mlir::Type adjustedElementType(mlir::Type t) {
+  if (auto ty = t.dyn_cast<fir::ReferenceType>()) {
+    auto eleTy = ty.getEleTy();
+    if (fir::isa_char(eleTy))
+      return eleTy;
+    if (fir::isa_derived(eleTy))
+      return eleTy;
+    if (eleTy.isa<fir::SequenceType>())
+      return eleTy;
+  }
+  return t;
+}
+
 std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
   if (auto sh = shape())
     if (auto *op = sh.getDefiningOp()) {
@@ -195,6 +212,103 @@
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ArrayMergeStoreOp
+//===----------------------------------------------------------------------===//
+
+/// Verify that `original` is produced by a fir.array_load and that the types
+/// of `original` and `sequence` agree with the element type of `memref`, or
+/// with the projected component type when a slice selects subfields.
+static mlir::LogicalResult verify(fir::ArrayMergeStoreOp op) {
+  // `original` may be a block argument, in which case it has no defining op.
+  if (!llvm::isa_and_nonnull<fir::ArrayLoadOp>(op.original().getDefiningOp()))
+    return op.emitOpError("operand #0 must be result of a fir.array_load op");
+  if (auto sl = op.slice()) {
+    // The slice may be a block argument or come from an op other than
+    // fir.slice; only a visible fir.slice's field path can be checked here.
+    if (auto sliceOp =
+            mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp())) {
+      if (!sliceOp.fields().empty()) {
+        // This is an intra-object merge, where the slice is projecting the
+        // subfields that are to be overwritten by the merge operation.
+        auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
+        if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) {
+          auto projTy =
+              fir::applyPathToType(seqTy.getEleTy(), sliceOp.fields());
+          if (fir::unwrapSequenceType(op.original().getType()) != projTy)
+            return op.emitOpError(
+                "type of origin does not match sliced memref type");
+          if (fir::unwrapSequenceType(op.sequence().getType()) != projTy)
+            return op.emitOpError(
+                "type of sequence does not match sliced memref type");
+          return mlir::success();
+        }
+        return op.emitOpError("referenced type is not an array");
+      }
+    }
+    return mlir::success();
+  }
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
+  if (op.original().getType() != eleTy)
+    return op.emitOpError("type of origin does not match memref element type");
+  if (op.sequence().getType() != eleTy)
+    return op.emitOpError(
+        "type of sequence does not match memref element type");
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayFetchOp
+//===----------------------------------------------------------------------===//
+
+/// Return the type selected by applying the indices of `op` (an array_fetch or
+/// array_update) to its sequence type, or a null type if they do not type
+/// check. Shared by the array_fetch and array_update verifiers.
+template <typename A>
+static mlir::Type validArraySubobject(A op) {
+  auto ty = op.sequence().getType();
+  return fir::applyPathToType(ty, op.indices());
+}
+
+/// Verify the index count and result type of a fir.array_fetch, and that its
+/// array value operand is produced by a fir.array_load.
+static mlir::LogicalResult verify(fir::ArrayFetchOp op) {
+  auto arrTy = op.sequence().getType().cast<fir::SequenceType>();
+  auto indSize = op.indices().size();
+  if (indSize < arrTy.getDimension())
+    return op.emitOpError("number of indices != dimension of array");
+  auto eleTy = ::adjustedElementType(op.element().getType());
+  if (indSize == arrTy.getDimension() && eleTy != arrTy.getEleTy())
+    return op.emitOpError("return type does not match array");
+  auto ty = validArraySubobject(op);
+  if (!ty || ty != eleTy)
+    return op.emitOpError("return type and/or indices do not type check");
+  // `sequence` may be a block argument, in which case it has no defining op.
+  if (!llvm::isa_and_nonnull<fir::ArrayLoadOp>(op.sequence().getDefiningOp()))
+    return op.emitOpError("argument #0 must be result of fir.array_load");
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayUpdateOp
+//===----------------------------------------------------------------------===//
+
+/// Verify that a fir.array_update has at least one index per array dimension
+/// and that the indices select a subobject whose type matches the merge value.
+static mlir::LogicalResult verify(fir::ArrayUpdateOp op) {
+  auto arrTy = op.sequence().getType().cast<fir::SequenceType>();
+  auto indSize = op.indices().size();
+  if (indSize < arrTy.getDimension())
+    return op.emitOpError("number of indices != dimension of array");
+  auto mergeTy = ::adjustedElementType(op.merge().getType());
+  if (indSize == arrTy.getDimension() && mergeTy != arrTy.getEleTy())
+    return op.emitOpError("merged value does not have element type");
+  auto ty = validArraySubobject(op);
+  if (!ty || ty != mergeTy)
+    return op.emitOpError("merged value and/or indices do not type check");
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // BoxAddrOp
 //===----------------------------------------------------------------------===//
@@ -2197,6 +2311,62 @@
   return false;
 }
 
+/// Return true iff `cop` is an integer constant that is a valid index into an
+/// aggregate with `size` members. Used by applyPathToType so that an
+/// out-of-range or non-integer constant selector yields a null type (and thus
+/// a verification error) instead of reaching the indexed type accessors.
+static bool isConstIndexInRange(mlir::ConstantOp cop, std::size_t size) {
+  if (!cop.getValue().isa<mlir::IntegerAttr>())
+    return false;
+  auto idx = fir::toInt(cop);
+  return idx >= 0 && static_cast<std::size_t>(idx) < size;
+}
+
+mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
+  // Each case consumes the path entries it needs and yields the type of the
+  // selected component; a null type ends the walk and signals an error.
+  for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
+    eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
+                .Case<fir::RecordType>([&](fir::RecordType ty) {
+                  if (auto *op = (*i++).getDefiningOp()) {
+                    if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op))
+                      return ty.getType(off.getFieldName());
+                    if (auto off = mlir::dyn_cast<mlir::ConstantOp>(op))
+                      if (isConstIndexInRange(off, ty.getTypeList().size()))
+                        return ty.getType(fir::toInt(off));
+                  }
+                  return mlir::Type{};
+                })
+                .Case<fir::SequenceType>([&](fir::SequenceType ty) {
+                  bool valid = true;
+                  const auto rank = ty.getDimension();
+                  for (std::remove_const_t<decltype(rank)> ii = 0;
+                       valid && ii < rank; ++ii)
+                    valid = i < end && fir::isa_integer((*i++).getType());
+                  return valid ? ty.getEleTy() : mlir::Type{};
+                })
+                .Case<mlir::TupleType>([&](mlir::TupleType ty) {
+                  if (auto *op = (*i++).getDefiningOp())
+                    if (auto off = mlir::dyn_cast<mlir::ConstantOp>(op))
+                      if (isConstIndexInRange(off, ty.size()))
+                        return ty.getType(fir::toInt(off));
+                  return mlir::Type{};
+                })
+                .Case<fir::ComplexType>([&](fir::ComplexType ty) {
+                  if (fir::isa_integer((*i++).getType()))
+                    return ty.getElementType();
+                  return mlir::Type{};
+                })
+                .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+                  if (fir::isa_integer((*i++).getType()))
+                    return ty.getElementType();
+                  return mlir::Type{};
+                })
+                .Default([&](const auto &) { return mlir::Type{}; });
+  }
+  return eleTy;
+}
+
 
 // Tablegen operators
 #define GET_OP_CLASSES
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -631,12 +631,12 @@
   // CHECK: [[AV1:%.*]] = fir.array_load [[ARR1]]([[SHAPE]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
   // CHECK: [[FVAL:%.*]] = fir.array_fetch [[AV1]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, index, index) -> f32
   // CHECK: [[AV2:%.*]] = fir.array_update [[AV1]], [[FVAL]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
-  // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.ref<!fir.array<?x?xf32>>
+  // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.array<?x?xf32>, !fir.array<?x?xf32>, !fir.ref<!fir.array<?x?xf32>>
   %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
   %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>,
!fir.shapeshift<2>) -> !fir.array %f = fir.array_fetch %av1, %i10, %j20 : (!fir.array, index, index) -> f32 %av2 = fir.array_update %av1, %f, %i10, %j20 : (!fir.array, f32, index, index) -> !fir.array - fir.array_merge_store %av1, %av2 to %arr1 : !fir.ref> + fir.array_merge_store %av1, %av2 to %arr1 : !fir.array, !fir.array, !fir.ref> return } diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -494,3 +494,57 @@ fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array, !fir.ref>, !fir.shape<1>, index return } + +// ----- + +func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { + %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> + %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + // expected-error@+1 {{'fir.array_fetch' op number of indices != dimension of array}} + %f = fir.array_fetch %av1, %m : (!fir.array, index) -> f32 + return +} + +// ----- + +func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { + %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> + %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + // expected-error@+1 {{'fir.array_fetch' op return type does not match array}} + %f = fir.array_fetch %av1, %m, %n : (!fir.array, index, index) -> i32 + return +} + +// ----- + +func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { + %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> + %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + %f = fir.array_fetch %av1, %m, %n : (!fir.array, index, index) -> f32 + // expected-error@+1 {{'fir.array_update' op number of indices != dimension of array}} + %av2 = fir.array_update %av1, %f, %m : (!fir.array, f32, index) -> 
!fir.array + return +} + +// ----- + +func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { + %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> + %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + %c0 = constant 0 : i32 + // expected-error@+1 {{'fir.array_update' op merged value does not have element type}} + %av2 = fir.array_update %av1, %c0, %m, %n : (!fir.array, i32, index, index) -> !fir.array + return +} + +// ----- + +func @test_misc_ops(%arr1 : !fir.ref>, %m : index, %n : index, %o : index, %p : index) { + %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2> + %av1 = fir.array_load %arr1(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + %f = fir.array_fetch %av1, %m, %n : (!fir.array, index, index) -> f32 + %av2 = fir.array_update %av1, %f, %m, %n : (!fir.array, f32, index, index) -> !fir.array + // expected-error@+1 {{'fir.array_merge_store' op operand #0 must be result of a fir.array_load op}} + fir.array_merge_store %av2, %av2 to %arr1 : !fir.array, !fir.array, !fir.ref> + return +}

Review notes on the test hunks above (free text after the last hunk):
* fir-ops.fir: the fir.array_merge_store round-trip now lists the types of
  all three operands, matching the new `type(operands)` assembly format.
* invalid.fir: five new cases, separated by `// -----`, one per verifier
  diagnostic: array_fetch with fewer indices than the array rank,
  array_fetch whose result type (i32) differs from the element type (f32),
  array_update with fewer indices than the array rank, array_update merging
  an i32 constant into an f32 array, and array_merge_store whose operand #0
  is an array_update result rather than an array_load result.
* NOTE(review): text between `<` and `>` in FIR types (e.g. the element
  shape of `!fir.array<...>` and the pointee of `!fir.ref<...>`) appears to
  have been stripped from this copy; restore it from the original commit
  before applying, since these inputs do not parse as shown.