diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -21,6 +21,8 @@ #include "flang/Optimizer/Support/KindMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" namespace fir { class AbstractArrayBox; @@ -57,6 +59,13 @@ /// Get a reference to the kind map. const fir::KindMapping &getKindMap() { return kindMap; } + /// Convert \p val to \p toTy where the two types may disagree (e.g. the LHS + /// and RHS of an assignment). \p val is returned unchanged when its type is + /// already \p toTy. NOTE(review): a COMPLEX vs. other scalar mismatch needs + /// an insert/extract of the COMPLEX parts; confirm the definition does this. + mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy, + mlir::Value val); + /// Get the entry block of the current Function mlir::Block *getEntryBlock() { return &getFunction().front(); } @@ -189,11 +198,6 @@ mlir::StringAttr createWeakLinkage() { return getStringAttr("weak"); } - /// Cast the input value to IndexType. - mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) { - return createConvert(loc, getIndexType(), val); - } - /// Get a function by name. If the function exists in the current module, it /// is returned. Otherwise, a null FuncOp is returned. mlir::FuncOp getNamedFunction(llvm::StringRef name) { @@ -243,6 +247,11 @@ return createFunction(loc, module, name, ty); } + /// Cast the input value to IndexType. + mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) { + return createConvert(loc, getIndexType(), val); + } + /// Construct one of the two forms of shape op from an array box.
mlir::Value genShape(mlir::Location loc, const fir::AbstractArrayBox &arr); mlir::Value genShape(mlir::Location loc, llvm::ArrayRef shift, @@ -253,6 +262,13 @@ /// this may create a `fir.shift` op. mlir::Value createShape(mlir::Location loc, const fir::ExtendedValue &exv); + /// Create a boxed value (Fortran descriptor) to be passed to the runtime. + /// \p exv is an extended value holding a memory reference to the object that + /// must be boxed. This function will crash if provided something that is not + /// a memory reference type. + /// Array entities are boxed with a shape and character with their length. + mlir::Value createBox(mlir::Location loc, const fir::ExtendedValue &exv); + /// Create constant i1 with value 1. if \p b is true or 0. otherwise mlir::Value createBool(mlir::Location loc, bool b) { return createIntegerConstant(loc, getIntegerType(1), b ? 1 : 0); @@ -326,6 +342,12 @@ /// Generate code testing \p addr is a null address. mlir::Value genIsNull(mlir::Location loc, mlir::Value addr); + /// Compute the extent of (lb:ub:step) as max((ub-lb+step)/step, 0). See + /// Fortran 2018 9.5.3.3.2 section for more details. + mlir::Value genExtentFromTriplet(mlir::Location loc, mlir::Value lb, + mlir::Value ub, mlir::Value step, + mlir::Type type); + private: const KindMapping &kindMap; }; @@ -345,6 +367,10 @@ mlir::Value readCharLen(fir::FirOpBuilder &builder, mlir::Location loc, const fir::ExtendedValue &box); +/// Read or get the extent in dimension \p dim of the array described by \p box. +mlir::Value readExtent(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &box, unsigned dim); + /// Read extents from \p box. llvm::SmallVector readExtents(fir::FirOpBuilder &builder, mlir::Location loc, @@ -369,6 +395,12 @@ /// to hint at the origin of the identifier. std::string uniqueCGIdent(llvm::StringRef prefix, llvm::StringRef name); +/// Lowers the extents from the sequence type to Values.
+/// Any unknown extents are lowered to undefined values. +llvm::SmallVector createExtents(fir::FirOpBuilder &builder, + mlir::Location loc, + fir::SequenceType seqTy); + //===----------------------------------------------------------------------===// // Location helpers //===----------------------------------------------------------------------===// @@ -379,6 +411,37 @@ /// Generate a constant of the given type with the location line number mlir::Value locationToLineNo(fir::FirOpBuilder &, mlir::Location, mlir::Type); +//===----------------------------------------------------------------------===// +// ExtendedValue helpers +//===----------------------------------------------------------------------===// + +/// Return the extended value describing a component of a derived type +/// instance, given only the address \p component of that component (no +/// extended value of the parent instance is taken as input). +fir::ExtendedValue componentToExtendedValue(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value component); + +/// Given the address of an array element and the ExtendedValue describing the +/// array, returns the ExtendedValue describing the array element. The purpose +/// is to propagate the length parameters of the array to the element. +/// This can be used for elements of `array` or `array(i:j:k)`. If \p element +/// belongs to an array section `array%x` whose base is \p array, +/// arraySectionElementToExtendedValue must be used instead. +fir::ExtendedValue arrayElementToExtendedValue(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &array, + mlir::Value element); + +/// Assign \p rhs to \p lhs. Both \p rhs and \p lhs must be scalar derived +/// types. The assignment follows Fortran intrinsic assignment semantics for +/// derived types (10.2.1.3 point 13).
+void genRecordAssignment(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &lhs, + const fir::ExtendedValue &rhs); + +mlir::TupleType getRaggedArrayHeaderType(fir::FirOpBuilder &builder); + } // namespace fir::factory #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Assign.h b/flang/include/flang/Optimizer/Builder/Runtime/Assign.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Builder/Runtime/Assign.h @@ -0,0 +1,32 @@ +//===-- Assign.h - generate assignment runtime API calls ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_ASSIGN_H +#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_ASSIGN_H + +namespace mlir { +class Value; +class Location; +} // namespace mlir + +namespace fir { +class FirOpBuilder; +} // namespace fir + +namespace fir::runtime { + +/// Generate runtime call to assign \p sourceBox to \p destBox. +/// \p destBox must be a reference to a fir.box and \p sourceBox a fir.box. +/// \p destBox Fortran descriptor may be modified if destBox is an allocatable +/// according to Fortran allocatable assignment rules, otherwise it is not +/// modified.
+void genAssign(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value destBox, mlir::Value sourceBox); + +} // namespace fir::runtime +#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_ASSIGN_H diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h @@ -0,0 +1,407 @@ +//===-- RTBuilder.h ---------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines some C++17 template classes that are used to convert the +/// signatures of plain old C functions into a model that can be used to +/// generate MLIR calls to those functions. This can be used to autogenerate +/// tables at compiler compile-time to call runtime support code. +/// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RTBUILDER_H +#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RTBUILDER_H + +#include "flang/Common/Fortran.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/SmallVector.h" +#include + +// Incomplete types indicating C99 complex ABI in interfaces. Beware, _Complex +// and std::complex are layout compatible, but not compatible in all ABI call +// interfaces (e.g. X86 32 bits). _Complex is not standard C++, so do not use +// it here.
+struct c_float_complex_t; +struct c_double_complex_t; + +namespace Fortran::runtime { +class Descriptor; +} // namespace Fortran::runtime + +namespace fir::runtime { + +using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *); +using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *); + +//===----------------------------------------------------------------------===// +// Type builder models +//===----------------------------------------------------------------------===// + +// TODO: all usages of sizeof in this file assume build == host == target. +// This will need to be revisited for cross compilation. + +/// Return a function that returns the type signature model for the type `T` +/// when provided an MLIRContext*. This allows one to translate C(++) function +/// signatures from runtime header files to MLIR signatures into a static table +/// at compile-time. +/// +/// For example, when `T` is `int`, return a function that returns the MLIR +/// standard type `i32` when `sizeof(int)` is 4.
+template +static constexpr TypeBuilderFunc getModel(); +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(short int)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(int)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get(mlir::IntegerType::get(context, 8)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get(mlir::IntegerType::get(context, 16)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get(mlir::IntegerType::get(context, 32)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(signed char)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::LLVMPointerType::get(mlir::IntegerType::get(context, 8)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get( + fir::LLVMPointerType::get(mlir::IntegerType::get(context, 8))); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 *
sizeof(long)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(std::size_t)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(unsigned long)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(unsigned long long)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::FloatType::getF64(context); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::FloatType::getF32(context); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +}
+template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 1); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel &>() { + return [](mlir::MLIRContext *context) -> mlir::Type { + auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context)); + return fir::ReferenceType::get(ty); + }; +} +template <> +constexpr TypeBuilderFunc getModel &>() { + return [](mlir::MLIRContext *context) -> mlir::Type { + auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context)); + return fir::ReferenceType::get(ty); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ComplexType::get(context, sizeof(float)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ComplexType::get(context, sizeof(double)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::BoxType::get(mlir::NoneType::get(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get( + fir::BoxType::get(mlir::NoneType::get(context))); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, + sizeof(Fortran::common::TypeCategory) * 8); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) ->
mlir::Type { + return mlir::NoneType::get(context); + }; +} + +template +struct RuntimeTableKey; +template +struct RuntimeTableKey { + static constexpr FuncTypeBuilderFunc getTypeModel() { + return [](mlir::MLIRContext *ctxt) { + TypeBuilderFunc ret = getModel(); + std::array args = {getModel()...}; + mlir::Type retTy = ret(ctxt); + llvm::SmallVector argTys; + for (auto f : args) + argTys.push_back(f(ctxt)); + return mlir::FunctionType::get(ctxt, argTys, {retTy}); + }; + } +}; + +//===----------------------------------------------------------------------===// +// Runtime table building (constexpr folded) +//===----------------------------------------------------------------------===// + +template +using RuntimeIdentifier = std::integer_sequence; + +namespace details { +template +static constexpr std::integer_sequence +concat(std::integer_sequence, std::integer_sequence) { + return {}; +} +template +static constexpr auto concat(std::integer_sequence, + std::integer_sequence, Cs...) { + return concat(std::integer_sequence{}, Cs{}...); +} +template +static constexpr std::integer_sequence concat(std::integer_sequence) { + return {}; +} +template +static constexpr auto filterZero(std::integer_sequence) { + if constexpr (a != 0) { + return std::integer_sequence{}; + } else { + return std::integer_sequence{}; + } +} +template +static constexpr auto filter(std::integer_sequence) { + if constexpr (sizeof...(b) > 0) { + return details::concat(filterZero(std::integer_sequence{})...); + } else { + return std::integer_sequence{}; + } +} +} // namespace details + +template +struct RuntimeTableEntry; +template +struct RuntimeTableEntry, RuntimeIdentifier> { + static constexpr FuncTypeBuilderFunc getTypeModel() { + return RuntimeTableKey::getTypeModel(); + } + static constexpr const char name[sizeof...(Cs) + 1] = {Cs..., '\0'}; +}; + +#undef E +#define E(L, I) (I < sizeof(L) / sizeof(*L) ?
L[I] : 0) +#define QuoteKey(X) #X +#define ExpandAndQuoteKey(X) QuoteKey(X) +#define MacroExpandKey(X) \ + E(X, 0), E(X, 1), E(X, 2), E(X, 3), E(X, 4), E(X, 5), E(X, 6), E(X, 7), \ + E(X, 8), E(X, 9), E(X, 10), E(X, 11), E(X, 12), E(X, 13), E(X, 14), \ + E(X, 15), E(X, 16), E(X, 17), E(X, 18), E(X, 19), E(X, 20), E(X, 21), \ + E(X, 22), E(X, 23), E(X, 24), E(X, 25), E(X, 26), E(X, 27), E(X, 28), \ + E(X, 29), E(X, 30), E(X, 31), E(X, 32), E(X, 33), E(X, 34), E(X, 35), \ + E(X, 36), E(X, 37), E(X, 38), E(X, 39), E(X, 40), E(X, 41), E(X, 42), \ + E(X, 43), E(X, 44), E(X, 45), E(X, 46), E(X, 47), E(X, 48), E(X, 49) +#define ExpandKey(X) MacroExpandKey(QuoteKey(X)) +#define FullSeq(X) std::integer_sequence +#define AsSequence(X) decltype(fir::runtime::details::filter(FullSeq(X){})) +#define mkKey(X) \ + fir::runtime::RuntimeTableEntry, \ + AsSequence(X)> +#define mkRTKey(X) mkKey(RTNAME(X)) + +/// Get (or generate) the MLIR FuncOp for a given runtime function. Its template +/// argument is intended to be of the form: mkRTKey(runtime function name). +/// Clients should add "using namespace Fortran::runtime" +/// in order to use this function. +template +static mlir::FuncOp getRuntimeFunc(mlir::Location loc, + fir::FirOpBuilder &builder) { + auto name = RuntimeEntry::name; + auto func = builder.getNamedFunction(name); + if (func) + return func; + auto funTy = RuntimeEntry::getTypeModel()(builder.getContext()); + func = builder.createFunction(loc, name, funTy); + func->setAttr("fir.runtime", builder.getUnitAttr()); + return func; +} + +namespace helper { +template +void createArguments(llvm::SmallVectorImpl &result, + fir::FirOpBuilder &builder, mlir::Location loc, + mlir::FunctionType fTy, A arg) { + result.emplace_back(builder.createConvert(loc, fTy.getInput(N), arg)); +} + +template +void createArguments(llvm::SmallVectorImpl &result, + fir::FirOpBuilder &builder, mlir::Location loc, + mlir::FunctionType fTy, A arg, As...
args) { + result.emplace_back(builder.createConvert(loc, fTy.getInput(N), arg)); + createArguments(result, builder, loc, fTy, args...); +} +} // namespace helper + +/// Create a SmallVector of arguments for a runtime call. +template +llvm::SmallVector +createArguments(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::FunctionType fTy, As... args) { + llvm::SmallVector result; + helper::createArguments<0>(result, builder, loc, fTy, args...); + return result; +} + +} // namespace fir::runtime + +#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RTBUILDER_H diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -61,6 +61,7 @@ /// Attribute to mark Fortran entities with the CONTIGUOUS attribute. constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; } + /// Attribute to mark Fortran entities with the OPTIONAL attribute. constexpr llvm::StringRef getOptionalAttrName() { return "fir.optional"; } diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -224,6 +224,28 @@ }]; } +def fir_LLVMPointerType : FIR_Type<"LLVMPointer", "llvm_ptr"> { + let summary = "Like LLVM pointer type"; + + let description = [{ + A pointer type that does not have any of the constraints and semantics + of other FIR pointer types and that translates to llvm pointer types. + It is meant to implement indirection that cannot be expressed directly + in Fortran, but is needed to implement some Fortran features (e.g., + double indirections).
+ }]; + + let parameters = (ins "mlir::Type":$eleTy); + + let skipDefaultBuilders = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ + return Base::get(elementType.getContext(), elementType); + }]>, + ]; +} + def fir_PointerType : FIR_Type<"Pointer", "ptr"> { let summary = "Reference to a POINTER attribute type"; @@ -401,6 +423,12 @@ "mlir::Type":$eleTy), [{ return get(eleTy.getContext(), shape, eleTy, {}); }]>, + TypeBuilderWithInferredContext<(ins + "mlir::Type":$eleTy, + "size_t":$dimensions), [{ + llvm::SmallVector shape(dimensions, getUnknownExtent()); + return get(eleTy.getContext(), shape, eleTy, {}); + }]> ]; let extraClassDeclaration = [{ diff --git a/flang/include/flang/Optimizer/Transforms/Factory.h b/flang/include/flang/Optimizer/Transforms/Factory.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Transforms/Factory.h @@ -0,0 +1,257 @@ +//===-- Optimizer/Transforms/Factory.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Templates to generate more complex code patterns in transformation passes. +// In transformation passes, front-end information such as is available in +// lowering is not available.
+// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_FACTORY_H +#define FORTRAN_OPTIMIZER_TRANSFORMS_FACTORY_H + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "llvm/ADT/iterator_range.h" + +namespace mlir { +class Location; +class Value; +} // namespace mlir + +namespace fir::factory { + +constexpr llvm::StringRef attrFortranArrayOffsets() { + return "Fortran.offsets"; +} + +/// Generate a character copy with optimized forms. +/// +/// If the lengths are constant and equal and the kinds match, use load/store. +/// Otherwise, if the lengths are constant and the input is at least as long as +/// the output, generate a loop to move a truncated portion of the source to the +/// destination. Finally, if the lengths are runtime values or the destination +/// is longer than the source, move the entire source character and pad the +/// destination with spaces as needed.
+template +void genCharacterCopy(mlir::Value src, mlir::Value srcLen, mlir::Value dst, + mlir::Value dstLen, B &builder, mlir::Location loc) { + auto srcTy = + fir::dyn_cast_ptrEleTy(src.getType()).template cast(); + auto dstTy = + fir::dyn_cast_ptrEleTy(dst.getType()).template cast(); + if (!srcLen && !dstLen && srcTy.getFKind() == dstTy.getFKind() && + srcTy.getLen() == dstTy.getLen()) { + // same length and kind, so just use load and store + auto load = builder.template create(loc, src); + builder.template create(loc, load, dst); + return; + } + auto zero = builder.template create(loc, 0); + auto one = builder.template create(loc, 1); + auto toArrayTy = [&](fir::CharacterType ty) { + return fir::ReferenceType::get(fir::SequenceType::get( + fir::SequenceType::ShapeRef{fir::SequenceType::getUnknownExtent()}, + fir::CharacterType::getSingleton(ty.getContext(), ty.getFKind()))); + }; + auto toEleTy = [&](fir::ReferenceType ty) { + auto seqTy = ty.getEleTy().cast(); + return seqTy.getEleTy().cast(); + }; + auto toCoorTy = [&](fir::ReferenceType ty) { + return fir::ReferenceType::get(toEleTy(ty)); + }; + if (!srcLen && !dstLen && srcTy.getLen() >= dstTy.getLen()) { + auto upper = builder.template create( + loc, dstTy.getLen() - 1); + auto loop = builder.template create(loc, zero, upper, one); + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(loop.getBody()); + auto csrcTy = toArrayTy(srcTy); + auto csrc = builder.template create(loc, csrcTy, src); + auto in = builder.template create( + loc, toCoorTy(csrcTy), csrc, loop.getInductionVar()); + auto load = builder.template create(loc, in); + auto cdstTy = toArrayTy(dstTy); + auto cdst = builder.template create(loc, cdstTy, dst); + auto out = builder.template create( + loc, toCoorTy(cdstTy), cdst, loop.getInductionVar()); + mlir::Value cast = + srcTy.getFKind() == dstTy.getFKind() + ?
load.getResult() + : builder + .template create(loc, toEleTy(cdstTy), load) + .getResult(); + builder.template create(loc, cast, out); + builder.restoreInsertionPoint(insPt); + return; + } + auto minusOne = [&](mlir::Value v) -> mlir::Value { + return builder.template create( + loc, builder.template create(loc, one.getType(), v), + one); + }; + mlir::Value len = dstLen ? minusOne(dstLen) + : builder + .template create( + loc, dstTy.getLen() - 1) + .getResult(); + auto loop = builder.template create(loc, zero, len, one); + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(loop.getBody()); + mlir::Value slen = + srcLen + ? builder.template create(loc, one.getType(), srcLen) + .getResult() + : builder.template create(loc, srcTy.getLen()) + .getResult(); + auto cond = builder.template create( + loc, arith::CmpIPredicate::slt, loop.getInductionVar(), slen); + auto ifOp = builder.template create(loc, cond, /*withElse=*/true); + builder.setInsertionPointToStart(&ifOp.thenRegion().front()); + auto csrcTy = toArrayTy(srcTy); + auto csrc = builder.template create(loc, csrcTy, src); + auto in = builder.template create( + loc, toCoorTy(csrcTy), csrc, loop.getInductionVar()); + auto load = builder.template create(loc, in); + auto cdstTy = toArrayTy(dstTy); + auto cdst = builder.template create(loc, cdstTy, dst); + auto out = builder.template create( + loc, toCoorTy(cdstTy), cdst, loop.getInductionVar()); + mlir::Value cast = + srcTy.getFKind() == dstTy.getFKind() + ?
load.getResult() + : builder.template create(loc, toEleTy(cdstTy), load) + .getResult(); + builder.template create(loc, cast, out); + builder.setInsertionPointToStart(&ifOp.elseRegion().front()); + auto space = builder.template create( + loc, toEleTy(cdstTy), llvm::ArrayRef{' '}); + auto cdst2 = builder.template create(loc, cdstTy, dst); + auto out2 = builder.template create( + loc, toCoorTy(cdstTy), cdst2, loop.getInductionVar()); + builder.template create(loc, space, out2); + builder.restoreInsertionPoint(insPt); +} + +/// Get extents from fir.shape/fir.shape_shift op. Empty result if +/// \p shapeVal is null or is not one of those ops (e.g. a fir.shift). +/// TODO: ShapeOp and ShapeShiftOp should return `OperandRange` instead of +/// `std::vector` to avoid copies. +inline std::vector getExtents(mlir::Value shapeVal) { + if (shapeVal) + if (auto *shapeOp = shapeVal.getDefiningOp()) { + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getExtents(); + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getExtents(); + } + return {}; +} + +/// Get origins from fir.shape_shift/fir.shift op. Empty result if +/// \p shapeVal is null or is not one of those ops (e.g. a fir.shape). +inline std::vector getOrigins(mlir::Value shapeVal) { + if (shapeVal) + if (auto *shapeOp = shapeVal.getDefiningOp()) { + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getOrigins(); + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getOrigins(); + } + return {}; +} + +/// Convert the normalized indices on array_fetch and array_update to the +/// dynamic (and non-zero) origin required by array_coor. +/// Do not adjust any trailing components in the path as they specify a +/// particular path into the array value and must already correspond to the +/// structure of an element.
+template +llvm::SmallVector +originateIndices(mlir::Location loc, B &builder, mlir::Type memTy, + mlir::Value shapeVal, mlir::ValueRange indices) { + llvm::SmallVector result; + auto origins = getOrigins(shapeVal); + if (origins.empty()) { + assert(!shapeVal || mlir::isa(shapeVal.getDefiningOp())); + auto ty = fir::dyn_cast_ptrOrBoxEleTy(memTy); + assert(ty && ty.isa()); + auto seqTy = ty.cast(); + auto one = builder.template create(loc, 1); + const auto dimension = seqTy.getDimension(); + if (shapeVal) { + assert(dimension == mlir::cast(shapeVal.getDefiningOp()) + .getType() + .getRank()); + } + for (auto i : llvm::enumerate(indices)) { + if (i.index() < dimension) { + assert(fir::isa_integer(i.value().getType())); + result.push_back( + builder.template create(loc, i.value(), one)); + } else { + result.push_back(i.value()); + } + } + return result; + } + const auto dimension = origins.size(); + unsigned origOff = 0; + for (auto i : llvm::enumerate(indices)) { + if (i.index() < dimension) + result.push_back(builder.template create( + loc, i.value(), origins[origOff++])); + else + result.push_back(i.value()); + } + return result; +} + +template +llvm::SmallVector createLoopNest( + mlir::Location loc, B &builder, llvm::iterator_range lows, + llvm::iterator_range highs, llvm::iterator_range steps, + llvm::ArrayRef threadedVals, bool unordered = false) { + llvm::SmallVector loops; + llvm::SmallVector inners(threadedVals.begin(), + threadedVals.end()); + for (auto iter0 = lows.begin(), iter1 = highs.begin(), iter2 = steps.begin(); + iter1 != highs.end(); ++iter0, ++iter1, ++iter2) { + auto lp = builder.template create( + loc, *iter0, *iter1, *iter2, unordered, + /*finalCount=*/false, inners); + loops.push_back(lp); + inners.assign(lp.getRegionIterArgs().begin(), lp.getRegionIterArgs().end()); + builder.setInsertionPointToStart(lp.getBody()); + } + auto numLoops = loops.size(); + for (decltype(numLoops) i = 0; i + 1 < numLoops; ++i) { +
builder.setInsertionPointToEnd(loops[i].getBody()); + builder.template create(loc, loops[i + 1].getResults()); + } + if (!loops.empty()) // Empty bound ranges create no loops to step past. + builder.setInsertionPointAfter(loops[0]); + return loops; +} + +template +llvm::SmallVector createLoopNest( + mlir::Location loc, B &builder, llvm::ArrayRef lows, + llvm::ArrayRef highs, llvm::ArrayRef steps, + llvm::ArrayRef threadedVals, bool unordered = false) { + return createLoopNest( + loc, builder, llvm::make_range(lows.begin(), lows.end()), + llvm::make_range(highs.begin(), highs.end()), + llvm::make_range(steps.begin(), steps.end()), threadedVals, unordered); +} + +} // namespace fir::factory + +#endif // FORTRAN_OPTIMIZER_TRANSFORMS_FACTORY_H diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr createAbstractResultOptPass(); std::unique_ptr createAffineDemotionPass(); +std::unique_ptr createArrayValueCopyPass(); std::unique_ptr createFirToCfgPass(); std::unique_ptr createCharacterConversionPass(); std::unique_ptr createExternalNameConversionPass(); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -74,6 +74,25 @@ ]; } +def ArrayValueCopy : FunctionPass<"array-value-copy"> { + let summary = "Convert array value operations to memory operations."; + let description = [{ + Transform the set of array value primitives to a memory-based array + representation. + + The Ops `array_load`, `array_store`, `array_fetch`, and `array_update` are + used to manage abstract aggregate array values. A simple analysis is done + to determine if there are potential dependences between these operations.
+ If not, these array operations can be lowered to work directly on the memory + representation. If there is a potential conflict, a temporary is created + along with appropriate copy-in/copy-out operations. Here, a more refined + analysis might be deployed, such as using the affine framework. + + This pass is required before code gen to the LLVM IR dialect. + }]; + let constructor = "::fir::createArrayValueCopyPass()"; +} + def CharacterConversion : Pass<"character-conversion"> { let summary = "Convert CHARACTER entities with different KINDs"; let description = [{ diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -6,6 +6,7 @@ DoLoopHelper.cpp FIRBuilder.cpp MutableBox.cpp + Runtime/Assign.cpp DEPENDS FIRDialect diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -7,15 +7,20 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Lower/Todo.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/MutableBox.h" +#include "flang/Optimizer/Builder/Runtime/Assign.h" +#include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Support/FatalError.h" #include "flang/Optimizer/Support/InternalNames.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MD5.h" static llvm::cl::opt @@ -110,7 +115,7 @@ const llvm::APFloat &value) { if (fltTy.isa()) { auto attr = getFloatAttr(fltTy, value); -
return create(loc, fltTy, attr); + return create(loc, fltTy, attr); } llvm_unreachable("should use builtin floating-point type"); } @@ -257,6 +262,16 @@ return glob; } +mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc, + mlir::Type toTy, + mlir::Value val) { + assert(toTy && "store location must be typed"); + auto fromTy = val.getType(); + if (fromTy == toTy) + return val; + return createConvert(loc, toTy, val); +} + mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc, mlir::Type toTy, mlir::Value val) { if (val.getType() != toTy) { @@ -327,23 +342,84 @@ [&](auto) -> mlir::Value { fir::emitFatalError(loc, "not an array"); }); } -static mlir::Value genNullPointerComparison(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value addr, - arith::CmpIPredicate condition) { +mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc, + const fir::ExtendedValue &exv) { + auto itemAddr = fir::getBase(exv); + if (itemAddr.getType().isa()) + return itemAddr; + auto elementType = fir::dyn_cast_ptrEleTy(itemAddr.getType()); + if (!elementType) + mlir::emitError(loc, "internal: expected a memory reference type ") + << itemAddr.getType(); + auto boxTy = fir::BoxType::get(elementType); + return exv.match( + [&](const fir::ArrayBoxValue &box) -> mlir::Value { + auto s = createShape(loc, exv); + return create(loc, boxTy, itemAddr, s); + }, + [&](const fir::CharArrayBoxValue &box) -> mlir::Value { + auto s = createShape(loc, exv); + if (fir::factory::CharacterExprHelper::hasConstantLengthInType(exv)) + return create(loc, boxTy, itemAddr, s); + + mlir::Value emptySlice; + llvm::SmallVector lenParams{box.getLen()}; + return create(loc, boxTy, itemAddr, s, emptySlice, + lenParams); + }, + [&](const fir::CharBoxValue &box) -> mlir::Value { + if (fir::factory::CharacterExprHelper::hasConstantLengthInType(exv)) + return create(loc, boxTy, itemAddr); + mlir::Value emptyShape, emptySlice; + llvm::SmallVector lenParams{box.getLen()}; + return 
create(loc, boxTy, itemAddr, emptyShape, + emptySlice, lenParams); + }, + [&](const fir::MutableBoxValue &x) -> mlir::Value { + return create( + loc, fir::factory::getMutableIRBox(*this, loc, x)); + }, + [&](const auto &) -> mlir::Value { + return create(loc, boxTy, itemAddr); + }); +} + +static mlir::Value +genNullPointerComparison(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value addr, + mlir::arith::CmpIPredicate condition) { auto intPtrTy = builder.getIntPtrType(); auto ptrToInt = builder.createConvert(loc, intPtrTy, addr); auto c0 = builder.createIntegerConstant(loc, intPtrTy, 0); - return builder.create(loc, condition, ptrToInt, c0); + return builder.create(loc, condition, ptrToInt, c0); } mlir::Value fir::FirOpBuilder::genIsNotNull(mlir::Location loc, mlir::Value addr) { - return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::ne); + return genNullPointerComparison(*this, loc, addr, + mlir::arith::CmpIPredicate::ne); } mlir::Value fir::FirOpBuilder::genIsNull(mlir::Location loc, mlir::Value addr) { - return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::eq); + return genNullPointerComparison(*this, loc, addr, + mlir::arith::CmpIPredicate::eq); +} + +mlir::Value fir::FirOpBuilder::genExtentFromTriplet(mlir::Location loc, + mlir::Value lb, + mlir::Value ub, + mlir::Value step, + mlir::Type type) { + auto zero = createIntegerConstant(loc, type, 0); + lb = createConvert(loc, type, lb); + ub = createConvert(loc, type, ub); + step = createConvert(loc, type, step); + auto diff = create(loc, ub, lb); + auto add = create(loc, diff, step); + auto div = create(loc, add, step); + auto cmp = create(loc, mlir::arith::CmpIPredicate::sgt, + div, zero); + return create(loc, cmp, div, zero); } //===--------------------------------------------------------------------===// @@ -376,6 +452,38 @@ }); } +mlir::Value fir::factory::readExtent(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &box, + unsigned 
dim) { + assert(box.rank() > dim); + return box.match( + [&](const fir::ArrayBoxValue &x) -> mlir::Value { + return x.getExtents()[dim]; + }, + [&](const fir::CharArrayBoxValue &x) -> mlir::Value { + return x.getExtents()[dim]; + }, + [&](const fir::BoxValue &x) -> mlir::Value { + if (!x.getExplicitExtents().empty()) + return x.getExplicitExtents()[dim]; + auto idxTy = builder.getIndexType(); + auto dimVal = builder.createIntegerConstant(loc, idxTy, dim); + return builder + .create(loc, idxTy, idxTy, idxTy, x.getAddr(), + dimVal) + .getResult(1); + }, + [&](const fir::MutableBoxValue &x) -> mlir::Value { + // MutableBoxValue must be read into another category to work with them + // outside of allocation/assignment contexts. + fir::emitFatalError(loc, "readExtents on MutableBoxValue"); + }, + [&](const auto &) -> mlir::Value { + fir::emitFatalError(loc, "extent inquiry on scalar"); + }); +} + llvm::SmallVector fir::factory::readExtents(fir::FirOpBuilder &builder, mlir::Location loc, const fir::BoxValue &box) { @@ -474,3 +582,171 @@ loc, builder.getCharacterLengthType(), str.size()); return fir::CharBoxValue{addr, len}; } + +llvm::SmallVector +fir::factory::createExtents(fir::FirOpBuilder &builder, mlir::Location loc, + fir::SequenceType seqTy) { + llvm::SmallVector extents; + auto idxTy = builder.getIndexType(); + for (auto ext : seqTy.getShape()) + extents.emplace_back( + ext == fir::SequenceType::getUnknownExtent() + ? builder.create(loc, idxTy).getResult() + : builder.createIntegerConstant(loc, idxTy, ext)); + return extents; +} + +// FIXME: This needs some work. To correctly determine the extended value of a +// component, one needs the base object, its type, and its type parameters. (An +// alternative would be to provide an already computed address of the final +// component rather than the base object's address, the point being the result +// will require the address of the final component to create the extended +// value.) 
One further needs the full path of components being applied. One +// needs to apply type-based expressions to type parameters along this said +// path. (See applyPathToType for a type-only derivation.) Finally, one needs to +// compose the extended value of the terminal component, including all of its +// parameters: array lower bounds expressions, extents, type parameters, etc. +// Any of these properties may be deferred until runtime in Fortran. This +// operation may therefore generate a sizeable block of IR, including calls to +// type-based helper functions, so caching the result of this operation in the +// client would be advised as well. +fir::ExtendedValue fir::factory::componentToExtendedValue( + fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value component) { + auto fieldTy = component.getType(); + if (auto ty = fir::dyn_cast_ptrEleTy(fieldTy)) + fieldTy = ty; + if (fieldTy.isa()) { + llvm::SmallVector nonDeferredTypeParams; + auto eleTy = fir::unwrapSequenceType(fir::dyn_cast_ptrOrBoxEleTy(fieldTy)); + if (auto charTy = eleTy.dyn_cast()) { + auto lenTy = builder.getCharacterLengthType(); + if (charTy.hasConstantLen()) + nonDeferredTypeParams.emplace_back( + builder.createIntegerConstant(loc, lenTy, charTy.getLen())); + // TODO: Starting with F2003, the dynamic character length might be dependent + // on a PDT length parameter. There is no way to make a difference with + // deferred length here yet.
+ } + if (auto recTy = eleTy.dyn_cast()) + if (recTy.getNumLenParams() > 0) + TODO(loc, "allocatable and pointer components non deferred length " + "parameters"); + + return fir::MutableBoxValue(component, nonDeferredTypeParams, + /*mutableProperties=*/{}); + } + llvm::SmallVector extents; + if (auto seqTy = fieldTy.dyn_cast()) { + fieldTy = seqTy.getEleTy(); + auto idxTy = builder.getIndexType(); + for (auto extent : seqTy.getShape()) { + if (extent == fir::SequenceType::getUnknownExtent()) + TODO(loc, "array component shape depending on length parameters"); + extents.emplace_back(builder.createIntegerConstant(loc, idxTy, extent)); + } + } + if (auto charTy = fieldTy.dyn_cast()) { + auto cstLen = charTy.getLen(); + if (cstLen == fir::CharacterType::unknownLen()) + TODO(loc, "get character component length from length type parameters"); + auto len = builder.createIntegerConstant( + loc, builder.getCharacterLengthType(), cstLen); + if (!extents.empty()) + return fir::CharArrayBoxValue{component, len, extents}; + return fir::CharBoxValue{component, len}; + } + if (auto recordTy = fieldTy.dyn_cast()) + if (recordTy.getNumLenParams() != 0) + TODO(loc, + "lower component ref that is a derived type with length parameter"); + if (!extents.empty()) + return fir::ArrayBoxValue{component, extents}; + return component; +} + +fir::ExtendedValue fir::factory::arrayElementToExtendedValue( + fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &array, mlir::Value element) { + return array.match( + [&](const fir::CharBoxValue &cb) -> fir::ExtendedValue { + return cb.clone(element); + }, + [&](const fir::CharArrayBoxValue &bv) -> fir::ExtendedValue { + return bv.cloneElement(element); + }, + [&](const fir::BoxValue &box) -> fir::ExtendedValue { + if (box.isCharacter()) { + auto len = fir::factory::readCharLen(builder, loc, box); + return fir::CharBoxValue{element, len}; + } + if (box.isDerivedWithLengthParameters()) + TODO(loc, "get length parameters from
derived type BoxValue"); + return element; + }, + [&](const auto &) -> fir::ExtendedValue { return element; }); +} + +/// Can the assignment of this record type be implemented with a simple memory +/// copy? +static bool recordTypeCanBeMemCopied(fir::RecordType recordType) { + if (fir::hasDynamicSize(recordType)) + return false; + for (auto [_, fieldType] : recordType.getTypeList()) { + // Derived type component may have user assignment (so far, we cannot tell + // in FIR, so assume it is always the case, TODO: get the actual info). + if (fir::unwrapSequenceType(fieldType).isa()) + return false; + // Allocatable components need deep copy. + if (auto boxType = fieldType.dyn_cast()) + if (boxType.getEleTy().isa()) + return false; + } + // Constant size components without user defined assignment and pointers can + // be memcopied. + return true; +} + +void fir::factory::genRecordAssignment(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &lhs, + const fir::ExtendedValue &rhs) { + assert(lhs.rank() == 0 && rhs.rank() == 0 && "assume scalar assignment"); + auto baseTy = fir::dyn_cast_ptrOrBoxEleTy(fir::getBase(lhs).getType()); + assert(baseTy && "must be a memory type"); + // Box operands may be polymorphic, it is not entirely clear from 10.2.1.3 + // if the assignment is performed on the dynamic or declared type. Use the + // runtime assuming it is performed on the dynamic type. + bool hasBoxOperands = fir::getBase(lhs).getType().isa() || + fir::getBase(rhs).getType().isa(); + auto recTy = baseTy.dyn_cast(); + assert(recTy && "must be a record type"); + if (hasBoxOperands || !recordTypeCanBeMemCopied(recTy)) { + auto to = fir::getBase(builder.createBox(loc, lhs)); + auto from = fir::getBase(builder.createBox(loc, rhs)); + // The runtime entry point may modify the LHS descriptor if it is + // an allocatable.
Allocatable assignment is handled elsewhere in lowering, + // so just create a fir.ref> from the fir.box to comply with the + // runtime interface, but assume the fir.box is unchanged. + // TODO: does this hold true with polymorphic entities? + auto toMutableBox = builder.createTemporary(loc, to.getType()); + builder.create(loc, to, toMutableBox); + fir::runtime::genAssign(builder, loc, toMutableBox, from); + return; + } + // Otherwise, the derived type has compile time constant size and for which + // the component by component assignment can be replaced by a memory copy. + auto rhsVal = fir::getBase(rhs); + if (fir::isa_ref_type(rhsVal.getType())) + rhsVal = builder.create(loc, rhsVal); + builder.create(loc, rhsVal, fir::getBase(lhs)); +} + +mlir::TupleType +fir::factory::getRaggedArrayHeaderType(fir::FirOpBuilder &builder) { + auto i64Ty = builder.getIntegerType(64); + auto arrTy = fir::SequenceType::get(builder.getIntegerType(8), 1); + auto buffTy = fir::HeapType::get(arrTy); + auto extTy = fir::SequenceType::get(i64Ty, 1); + auto shTy = fir::HeapType::get(extTy); + return mlir::TupleType::get(builder.getContext(), {i64Ty, buffTy, shTy}); +} diff --git a/flang/lib/Optimizer/Builder/Runtime/Assign.cpp b/flang/lib/Optimizer/Builder/Runtime/Assign.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Builder/Runtime/Assign.cpp @@ -0,0 +1,26 @@ +//===-- Assign.cpp -- generate assignment runtime API calls ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/Runtime/Assign.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" +#include "flang/Runtime/assign.h" + +using namespace Fortran::runtime; + +void fir::runtime::genAssign(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value destBox, mlir::Value sourceBox) { + auto func = fir::runtime::getRuntimeFunc(loc, builder); + auto fTy = func.getType(); + auto sourceFile = fir::factory::locationToFilename(builder, loc); + auto sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); + auto args = fir::runtime::createArguments(builder, loc, fTy, destBox, + sourceBox, sourceFile, sourceLine); + builder.create(loc, func, args); +} diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1155,6 +1155,10 @@ // GlobalOp //===----------------------------------------------------------------------===// +mlir::Type fir::GlobalOp::resultType() { + return wrapAllocaResultType(getType()); +} + static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { // Parse the optional linkage llvm::StringRef linkage; @@ -1322,10 +1326,6 @@ // GlobalLenOp //===----------------------------------------------------------------------===// -mlir::Type fir::GlobalOp::resultType() { - return wrapAllocaResultType(getType()); -} - static mlir::ParseResult parseGlobalLenOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::StringRef fieldName; diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -200,15 +200,15 @@ mlir::Type
dyn_cast_ptrEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case( - [](auto p) { return p.getEleTy(); }) + .Case([](auto p) { return p.getEleTy(); }) .Default([](mlir::Type) { return mlir::Type{}; }); } mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case( - [](auto p) { return p.getEleTy(); }) + .Case([](auto p) { return p.getEleTy(); }) .Case([](auto p) { auto eleTy = p.getEleTy(); if (auto ty = fir::dyn_cast_ptrEleTy(eleTy)) @@ -471,6 +471,19 @@ printer << getMnemonic() << "<" << getFKind() << '>'; } +//===----------------------------------------------------------------------===// +// LLVMPointerType +//===----------------------------------------------------------------------===// + +// `llvm_ptr` `<` type `>` +mlir::Type fir::LLVMPointerType::parse(mlir::DialectAsmParser &parser) { + return parseTypeSingleton(parser); +} + +void fir::LLVMPointerType::print(mlir::DialectAsmPrinter &printer) const { + printer << getMnemonic() << "<" << getEleTy() << '>'; +} + //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// @@ -865,7 +878,7 @@ void FIROpsDialect::registerTypes() { addTypes(); + LLVMPointerType, PointerType, RealType, RecordType, ReferenceType, + SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType, + TypeDescType, fir::VectorType>(); } diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -0,0 +1,845 @@ +//===-- ArrayValueCopy.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "flang/Optimizer/Transforms/Factory.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "flang-array-value-copy" + +using namespace fir; + +using OperationUseMapT = llvm::DenseMap; + +namespace { + +/// Array copy analysis. +/// Perform an interference analysis between array values. +/// +/// Lowering will generate a sequence of the following form. +/// ```mlir +/// %a_1 = fir.array_load %array_1(%shape) : ... +/// ... +/// %a_j = fir.array_load %array_j(%shape) : ... +/// ... +/// %a_n = fir.array_load %array_n(%shape) : ... +/// ... +/// %v_i = fir.array_fetch %a_i, ... +/// %a_j1 = fir.array_update %a_j, ... +/// ... +/// fir.array_merge_store %a_j, %a_jn to %array_j : ... +/// ``` +/// +/// The analysis is to determine if there are any conflicts. A conflict is when +/// one of the following cases occurs. +/// +/// 1. There is an `array_update` to an array value, a_j, such that a_j was +/// loaded from the same array memory reference (array_j) but with a different +/// shape as the other array values a_i, where i != j. [Possible overlapping +/// arrays.] +/// +/// 2. There is either an array_fetch or array_update of a_j with a different +/// set of index values. [Possible loop-carried dependence.] +/// +/// If none of the array values overlap in storage and the accesses are not +/// loop-carried, then the arrays are conflict-free and no copies are required.
+class ArrayCopyAnalysis { +public: + using ConflictSetT = llvm::SmallPtrSet; + using UseSetT = llvm::SmallPtrSet; + using LoadMapSetsT = + llvm::DenseMap>; + + ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); } + + mlir::Operation *getOperation() const { return operation; } + + /// Return true iff the `array_merge_store` has potential conflicts. + bool hasPotentialConflict(mlir::Operation *op) const { + LLVM_DEBUG(llvm::dbgs() + << "looking for a conflict on " << *op + << " and the set has a total of " << conflicts.size() << '\n'); + return conflicts.contains(op); + } + + /// Return the use map. The use map maps array fetch and update operations + /// back to the array load that is the original source of the array value. + const OperationUseMapT &getUseMap() const { return useMap; } + + /// Find all the array operations that access the array value that is loaded + /// by the array load operation, `load`. + const llvm::SmallVector &arrayAccesses(ArrayLoadOp load); + +private: + void construct(mlir::Operation *topLevelOp); + + mlir::Operation *operation; // operation that analysis ran upon + ConflictSetT conflicts; // set of conflicts (loads and merge stores) + OperationUseMapT useMap; + LoadMapSetsT loadMapSets; +}; +} // namespace + +namespace { +/// Helper class to collect all array operations that produced an array value. +class ReachCollector { +private: + // If provided, the `loopRegion` is the body of a loop that produces the array + // of interest. + ReachCollector(llvm::SmallVectorImpl &reach, + mlir::Region *loopRegion) + : reach{reach}, loopRegion{loopRegion} {} + + void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { + LLVM_DEBUG(llvm::dbgs() << "collect range: " << *op << '\n'); + if (range.empty()) { + collectArrayAccessFrom(op, mlir::Value{}); + return; + } + for (mlir::Value v : range) + collectArrayAccessFrom(v); + } + + // TODO: Replace recursive algorithm on def-use chain with an iterative one + // with an explicit stack.
+ void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { + // `val` is defined by an Op, process the defining Op. + // If `val` is defined by a region containing Op, we want to drill down + // and through that Op's region(s). + // Trace the defining op (enabled with -debug-only=flang-array-value-copy). + LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); + auto popFn = [&](auto rop) { + assert(val && "op must have a result value"); + auto resNum = val.cast().getResultNumber(); + llvm::SmallVector results; + rop.resultToSourceOps(results, resNum); + for (auto u : results) + collectArrayAccessFrom(u); + }; + if (auto rop = mlir::dyn_cast(op)) { + popFn(rop); + return; + } + if (auto rop = mlir::dyn_cast(op)) { + popFn(rop); + return; + } + if (auto mergeStore = mlir::dyn_cast(op)) { + if (opIsInsideLoops(mergeStore)) + collectArrayAccessFrom(mergeStore.sequence()); + return; + } + + if (mlir::isa(op)) { + // Look for any stores inside the loops, and collect an array operation + // that produced the value being stored to it. + for (mlir::Operation *user : op->getUsers()) + if (auto store = mlir::dyn_cast(user)) + if (opIsInsideLoops(store)) + collectArrayAccessFrom(store.value()); + return; + } + + // Otherwise, Op does not contain a region so just chase its operands. + if (mlir::isa( + op)) { + LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); + reach.emplace_back(op); + } + // Array modify assignment is performed on the result. So the analysis + // must look at what is done with the result. + if (mlir::isa(op)) + for (mlir::Operation *user : op->getResult(0).getUsers()) + followUsers(user); + + for (auto u : op->getOperands()) + collectArrayAccessFrom(u); + } + + void collectArrayAccessFrom(mlir::BlockArgument ba) { + auto *parent = ba.getOwner()->getParentOp(); + // If inside an Op holding a region, the block argument corresponds to an + // argument passed to the containing Op.
+ auto popFn = [&](auto rop) { + collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); + }; + if (auto rop = mlir::dyn_cast(parent)) { + popFn(rop); + return; + } + if (auto rop = mlir::dyn_cast(parent)) { + popFn(rop); + return; + } + // Otherwise, a block argument is provided via the pred blocks. + for (auto *pred : ba.getOwner()->getPredecessors()) { + auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); + collectArrayAccessFrom(u); + } + } + + // Recursively trace operands to find all array operations relating to the + // values merged. + void collectArrayAccessFrom(mlir::Value val) { + if (!val || visited.contains(val)) + return; + visited.insert(val); + + // Process a block argument. + if (auto ba = val.dyn_cast()) { + collectArrayAccessFrom(ba); + return; + } + + // Process an Op. + if (auto *op = val.getDefiningOp()) { + collectArrayAccessFrom(op, val); + return; + } + + fir::emitFatalError(val.getLoc(), "unhandled value"); + } + + /// Is \p op inside the loop nest region? + bool opIsInsideLoops(mlir::Operation *op) const { + return loopRegion && loopRegion->isAncestor(op->getParentRegion()); + } + + /// Recursively trace the uses of an operation's results, calling + /// collectArrayAccessFrom on the direct and indirect user operands. + /// TODO: Replace recursive algorithm on def-use chain with an iterative one + /// with an explicit stack. + void followUsers(mlir::Operation *op) { + for (auto userOperand : op->getOperands()) + collectArrayAccessFrom(userOperand); + // Go through potential converts/coordinate_op. + for (mlir::Operation *indirectUser : op->getUsers()) + followUsers(indirectUser); + } + + llvm::SmallVectorImpl &reach; + llvm::SmallPtrSet visited; + /// Region of the loop nest that produced the array value. + mlir::Region *loopRegion; + +public: + /// Return all ops that produce the array value that is stored into the + /// `array_merge_store`.
+ static void reachingValues(llvm::SmallVectorImpl &reach, + mlir::Value seq) { + reach.clear(); + mlir::Region *loopRegion = nullptr; + // Only `DoLoopOp` is tested here since array operations are currently only + // associated with this kind of loop. + if (auto doLoop = + mlir::dyn_cast_or_null(seq.getDefiningOp())) + loopRegion = &doLoop->getRegion(0); + ReachCollector collector(reach, loopRegion); + collector.collectArrayAccessFrom(seq); + } +}; +} // namespace + +/// Find all the array operations that access the array value that is loaded by +/// the array load operation, `load`. +const llvm::SmallVector & +ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { + auto lmIter = loadMapSets.find(load); + if (lmIter != loadMapSets.end()) + return lmIter->getSecond(); + + llvm::SmallVector accesses; + UseSetT visited; + llvm::SmallVector queue; // uses of ArrayLoad[orig] + + auto appendToQueue = [&](mlir::Value val) { + for (mlir::OpOperand &use : val.getUses()) + if (!visited.count(&use)) { + visited.insert(&use); + queue.push_back(&use); + } + }; + + // Build the set of uses of `original`. + // let USES = { uses of original fir.load } + appendToQueue(load); + + // Process the worklist until done. + while (!queue.empty()) { + mlir::OpOperand *operand = queue.pop_back_val(); + mlir::Operation *owner = operand->getOwner(); + + auto structuredLoop = [&](auto ro) { + if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { + int64_t arg = blockArg.getArgNumber(); + mlir::Value output = ro.getResult(ro.finalValue() ? arg : arg - 1); + appendToQueue(output); + appendToQueue(blockArg); + } + }; + // TODO: this needs to be updated to use the control-flow interface. + auto branchOp = [&](mlir::Block *dest, OperandRange operands) { + if (operands.empty()) + return; + + // Check if this operand is within the range.
+ unsigned operandIndex = operand->getOperandNumber(); + unsigned operandsStart = operands.getBeginOperandIndex(); + if (operandIndex < operandsStart || + operandIndex >= (operandsStart + operands.size())) + return; + + // Index the successor. + unsigned argIndex = operandIndex - operandsStart; + appendToQueue(dest->getArgument(argIndex)); + }; + // Thread uses into structured loop bodies and return value uses. + if (auto ro = mlir::dyn_cast(owner)) { + structuredLoop(ro); + } else if (auto ro = mlir::dyn_cast(owner)) { + structuredLoop(ro); + } else if (auto rs = mlir::dyn_cast(owner)) { + // Thread any uses of fir.if that return the marked array value. + if (auto ifOp = rs->getParentOfType()) + appendToQueue(ifOp.getResult(operand->getOperandNumber())); + } else if (mlir::isa(owner)) { + // Keep track of array value fetches. + LLVM_DEBUG(llvm::dbgs() + << "add fetch {" << *owner << "} to array value set\n"); + accesses.push_back(owner); + } else if (auto update = mlir::dyn_cast(owner)) { + // Keep track of array value updates and thread the return value uses. + LLVM_DEBUG(llvm::dbgs() + << "add update {" << *owner << "} to array value set\n"); + accesses.push_back(owner); + appendToQueue(update.getResult()); + } else if (auto update = mlir::dyn_cast(owner)) { + // Keep track of array value modification and thread the return value + // uses.
+ LLVM_DEBUG(llvm::dbgs() + << "add modify {" << *owner << "} to array value set\n"); + accesses.push_back(owner); + appendToQueue(update.getResult(1)); + } else if (auto br = mlir::dyn_cast(owner)) { + branchOp(br.getDest(), br.destOperands()); + } else if (auto br = mlir::dyn_cast(owner)) { + branchOp(br.getTrueDest(), br.getTrueOperands()); + branchOp(br.getFalseDest(), br.getFalseOperands()); + } else if (mlir::isa(owner)) { + // do nothing + } else { + llvm::report_fatal_error("array value reached unexpected op"); + } + } + return loadMapSets.insert({load, accesses}).first->getSecond(); +} + +/// Is there a conflict between the array value that was updated and to be +/// stored to `st` and the set of arrays loaded (`reach`) and used to compute +/// the updated value? +static bool conflictOnLoad(llvm::ArrayRef reach, + ArrayMergeStoreOp st) { + mlir::Value load; + auto addr = st.memref(); + auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); + for (auto *op : reach) { + auto ld = mlir::dyn_cast(op); + if (!ld) + continue; + auto ldTy = ld.memref().getType(); + if (auto boxTy = ldTy.dyn_cast()) + ldTy = boxTy.getEleTy(); + if (ldTy.isa() && stEleTy == dyn_cast_ptrEleTy(ldTy)) + return true; + if (ld.memref() == addr) { + if (ld.getResult() != st.original()) + return true; + if (load) + return true; + load = ld; + } + } + return false; +} + +/// Check if there is any potential conflict in the chained update operations +/// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the +/// array. A potential conflict is detected if two operations use different +/// index values.
+static bool conflictOnMerge(llvm::ArrayRef accesses) { + if (accesses.size() < 2) + return false; + llvm::SmallVector indices; + LLVM_DEBUG(llvm::dbgs() << "check merge conflict with " << accesses.size() + << " accesses on the list\n"); + for (auto *op : accesses) { + assert((mlir::isa(op)) && + "unexpected operation in analysis"); + llvm::SmallVector compareVector; + if (auto u = mlir::dyn_cast(op)) { + if (indices.empty()) { + indices = u.indices(); + continue; + } + compareVector = u.indices(); + } else if (auto f = mlir::dyn_cast(op)) { + if (indices.empty()) { + indices = f.indices(); + continue; + } + compareVector = f.indices(); + } else if (auto f = mlir::dyn_cast(op)) { + if (indices.empty()) { + indices = f.indices(); + continue; + } + compareVector = f.indices(); + } + if (compareVector != indices) + return true; + LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); + } + return false; +} + +// Is either type of conflict present? +static bool conflictDetected(llvm::ArrayRef reach, + llvm::ArrayRef accesses, + ArrayMergeStoreOp st) { + return conflictOnLoad(reach, st) || conflictOnMerge(accesses); +} + +/// Constructor of the array copy analysis. +/// This performs the analysis and saves the intermediate results.
+void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
+  topLevelOp->walk([&](Operation *op) {
+    if (auto st = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
+      llvm::SmallVector<mlir::Operation *> values;
+      ReachCollector::reachingValues(values, st.sequence());
+      // The `original` operand of the merge store is defined by the
+      // array_load it is paired with.
+      auto *ld = st.original().getDefiningOp();
+      const auto &accesses = arrayAccesses(mlir::cast<ArrayLoadOp>(ld));
+      if (conflictDetected(values, accesses, st)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "CONFLICT: copies required for " << st << '\n'
+                   << " adding conflicts on: " << *op << " and "
+                   << st.original() << '\n');
+        conflicts.insert(op);
+        conflicts.insert(ld);
+      }
+      LLVM_DEBUG(llvm::dbgs()
+                 << "map: adding {" << *ld << " -> " << st << "}\n");
+      useMap.insert({ld, op});
+    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
+      const auto &accesses = arrayAccesses(load);
+      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
+                              << ", accesses: " << accesses.size() << '\n');
+      for (auto *acc : accesses) {
+        LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n');
+        if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc)) {
+          // An access may only be reached from a single array_load.
+          if (useMap.count(acc)) {
+            mlir::emitError(
+                load.getLoc(),
+                "The parallel semantics of multiple array_merge_stores per "
+                "array_load are not supported.");
+            return;
+          }
+          LLVM_DEBUG(llvm::dbgs()
+                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
+          useMap.insert({acc, op});
+        }
+      }
+    }
+  });
+}
+
+namespace {
+class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(ArrayLoadOp load,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
+    rewriter.replaceOpWithNewOp<fir::UndefOp>(load, load.getType());
+    return mlir::success();
+  }
+};
+
+class ArrayMergeStoreConversion
+    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(ArrayMergeStoreOp store,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
+    rewriter.eraseOp(store);
+    return mlir::success();
+  }
+};
+} // namespace
+
+static mlir::Type getEleTy(mlir::Type ty) {
+  if (auto t = fir::dyn_cast_ptrEleTy(ty))
+    ty = t;
+  if (auto t = ty.dyn_cast<fir::SequenceType>())
+    ty = t.getEleTy();
+  // FIXME: keep ptr/heap/ref information.
+  return fir::ReferenceType::get(ty);
+}
+
+// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
+// TODO: getExtents on op should return a ValueRange instead of a vector.
+static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result,
+                       mlir::Value shape) {
+  auto *shapeOp = shape.getDefiningOp();
+  if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) {
+    auto e = s.getExtents();
+    result.insert(result.end(), e.begin(), e.end());
+    return;
+  }
+  if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) {
+    auto e = s.getExtents();
+    result.insert(result.end(), e.begin(), e.end());
+    return;
+  }
+  llvm::report_fatal_error("not a fir.shape/fir.shape_shift op");
+}
+
+// Place the extents of the array loaded by an ArrayLoadOp into the result
+// vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If
+// the ArrayLoadOp is loading a fir.box, code will be generated to read the
+// extents from the fir.box, and the returned ShapeOp is built with the read
+// extents.
+// Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp
+// argument of the ArrayLoadOp that is returned.
+static mlir::Value
+getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter,
+                           fir::ArrayLoadOp loadOp,
+                           llvm::SmallVectorImpl<mlir::Value> &result) {
+  assert(result.empty());
+  auto boxTy = loadOp.memref().getType().dyn_cast<fir::BoxType>();
+  if (!boxTy) {
+    // Not a descriptor: the shape operand already lists the extents.
+    getExtents(result, loadOp.shape());
+    return loadOp.shape();
+  }
+  // Descriptor: emit one fir.box_dims per dimension and keep result #1 (the
+  // extent) of each.
+  auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy)
+                  .cast<fir::SequenceType>()
+                  .getDimension();
+  auto indexTy = rewriter.getIndexType();
+  for (decltype(rank) d = 0; d != rank; ++d) {
+    auto dimIndex = rewriter.create<mlir::arith::ConstantIndexOp>(loc, d);
+    auto boxDims = rewriter.create<fir::BoxDimsOp>(
+        loc, indexTy, indexTy, indexTy, loadOp.memref(), dimIndex);
+    result.emplace_back(boxDims.getResult(1));
+  }
+  return rewriter.create<fir::ShapeOp>(
+      loc, fir::ShapeType::get(rewriter.getContext(), rank), result);
+}
+
+/// Return \p ty unchanged when it is already a reference type, otherwise wrap
+/// it in a !fir.ref.
+static mlir::Type toRefType(mlir::Type ty) {
+  if (!fir::isa_ref_type(ty))
+    ty = fir::ReferenceType::get(ty);
+  return ty;
+}
+
+/// Build the address of an array element: a fir.array_coor over the first
+/// `rank` indices, followed by a fir.coordinate_of over any remaining ones.
+static mlir::Value
+genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
+          mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
+          mlir::Value slice, mlir::ValueRange indices,
+          mlir::ValueRange typeparams, bool skipOrig = false) {
+  // With skipOrig the indices are used as given; otherwise they are rebased
+  // through fir::factory::originateIndices.
+  llvm::SmallVector<mlir::Value> coorIndices;
+  if (!skipOrig)
+    coorIndices = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
+                                                 shape, indices);
+  else
+    coorIndices.append(indices.begin(), indices.end());
+  auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
+  assert(seqTy && seqTy.isa<fir::SequenceType>());
+  const auto rank = seqTy.cast<fir::SequenceType>().getDimension();
+  llvm::ArrayRef<mlir::Value> allIndices{coorIndices};
+  mlir::Value addr = rewriter.create<fir::ArrayCoorOp>(
+      loc, eleTy, alloc, shape, slice, allIndices.take_front(rank), typeparams);
+  if (rank < allIndices.size())
+    addr = rewriter.create<fir::CoordinateOp>(loc, resTy, addr,
+                                              allIndices.drop_front(rank));
+  return addr;
+}
+
+namespace {
+/// Conversion of fir.array_update and fir.array_modify Ops.
+/// If there is a conflict for the update, then we need to perform a
+/// copy-in/copy-out to preserve the original values of the array. If there is
+/// no conflict, then it is safe to eschew making any copies.
+template <typename ArrayOp>
+class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
+public:
+  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
+                                     const ArrayCopyAnalysis &a,
+                                     const OperationUseMapT &m)
+      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
+
+  /// Emit a loop nest copying every element of \p src into \p dst, using the
+  /// extents described by \p shapeOp. The insertion point is restored on exit.
+  void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
+                    mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
+                    mlir::Type arrTy) const {
+    auto insPt = rewriter.saveInsertionPoint();
+    llvm::SmallVector<mlir::Value> indices;
+    llvm::SmallVector<mlir::Value> extents;
+    getExtents(extents, shapeOp);
+    auto idxTy = rewriter.getIndexType();
+    // Build loop nest from column to row.
+    for (auto sh : llvm::reverse(extents)) {
+      auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh);
+      auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+      auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
+      auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
+      auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one);
+      rewriter.setInsertionPointToStart(loop.getBody());
+      indices.push_back(loop.getInductionVar());
+    }
+    // Reverse the indices so they are in column-major order.
+    std::reverse(indices.begin(), indices.end());
+    auto ty = getEleTy(arrTy);
+    auto fromAddr = rewriter.create<fir::ArrayCoorOp>(
+        loc, ty, src, shapeOp, mlir::Value{},
+        fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp,
+                                       indices),
+        mlir::ValueRange{});
+    auto load = rewriter.create<fir::LoadOp>(loc, fromAddr);
+    auto toAddr = rewriter.create<fir::ArrayCoorOp>(
+        loc, ty, dst, shapeOp, mlir::Value{},
+        fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp,
+                                       indices),
+        mlir::ValueRange{});
+    rewriter.create<fir::StoreOp>(loc, load, toAddr);
+    rewriter.restoreInsertionPoint(insPt);
+  }
+
+  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
+  /// temp and the LHS if the analysis found potential overlaps between the RHS
+  /// and LHS arrays. The element copy generator must be provided through \p
+  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
+  /// Returns the address of the LHS element inside the loop and the LHS
+  /// ArrayLoad result.
+  std::pair<mlir::Value, mlir::Value>
+  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
+                        ArrayOp update,
+                        const std::function<void(mlir::Value)> &assignElement,
+                        mlir::Type lhsEltRefType) const {
+    auto *op = update.getOperation();
+    mlir::Operation *loadOp = useMap.lookup(op);
+    auto load = mlir::cast<ArrayLoadOp>(loadOp);
+    // Debug tracing goes to llvm::dbgs() like the rest of the pass, never to
+    // stdout.
+    LLVM_DEBUG(llvm::dbgs() << "does " << load << " have a conflict?\n");
+    if (analysis.hasPotentialConflict(loadOp)) {
+      // If there is a conflict between the arrays, then we copy the lhs array
+      // to a temporary, update the temporary, and copy the temporary back to
+      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
+      LLVM_DEBUG(llvm::dbgs() << "Yes, conflict was found\n");
+      rewriter.setInsertionPoint(loadOp);
+      // Copy in.
+      llvm::SmallVector<mlir::Value> extents;
+      mlir::Value shapeOp =
+          getOrReadExtentsAndShapeOp(loc, rewriter, load, extents);
+      auto allocmem = rewriter.create<fir::AllocMemOp>(
+          loc, fir::dyn_cast_ptrOrBoxEleTy(load.memref().getType()),
+          load.typeparams(), extents);
+      genArrayCopy(load.getLoc(), rewriter, allocmem, load.memref(), shapeOp,
+                   load.getType());
+      rewriter.setInsertionPoint(op);
+      mlir::Value coor = genCoorOp(
+          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
+          shapeOp, load.slice(), update.indices(), load.typeparams(),
+          update->hasAttr(fir::factory::attrFortranArrayOffsets()));
+      assignElement(coor);
+      auto *storeOp = useMap.lookup(loadOp);
+      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
+      rewriter.setInsertionPoint(storeOp);
+      // Copy out.
+      genArrayCopy(store.getLoc(), rewriter, store.memref(), allocmem, shapeOp,
+                   load.getType());
+      rewriter.create<fir::FreeMemOp>(loc, allocmem);
+      return {coor, load.getResult()};
+    }
+    // Otherwise, when there is no conflict (no possible loop-carried
+    // dependence), the lhs array can be updated in place.
+    LLVM_DEBUG(llvm::dbgs() << "No, conflict wasn't found\n");
+    rewriter.setInsertionPoint(op);
+    auto coorTy = getEleTy(load.getType());
+    mlir::Value coor = genCoorOp(
+        rewriter, loc, coorTy, lhsEltRefType, load.memref(), load.shape(),
+        load.slice(), update.indices(), load.typeparams(),
+        update->hasAttr(fir::factory::attrFortranArrayOffsets()));
+    assignElement(coor);
+    return {coor, load.getResult()};
+  }
+
+private:
+  const ArrayCopyAnalysis &analysis;
+  const OperationUseMapT &useMap;
+};
+
+class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
+public:
+  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
+                                 const ArrayCopyAnalysis &a,
+                                 const OperationUseMapT &m)
+      : ArrayUpdateConversionBase{ctx, a, m} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(ArrayUpdateOp update,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = update.getLoc();
+    auto assignElement = [&](mlir::Value coor) {
+      mlir::Value input = update.merge();
+      if (auto inEleTy = fir::dyn_cast_ptrEleTy(input.getType())) {
+        if (inEleTy.isa<fir::RecordType>()) {
+          fir::FirOpBuilder builder(
+              rewriter,
+              fir::getKindMapping(update->getParentOfType<mlir::ModuleOp>()));
+          if (!update.typeparams().empty()) {
+            auto boxTy = fir::BoxType::get(inEleTy);
+            mlir::Value emptyShape, emptySlice;
+            auto lhs = rewriter.create<fir::EmboxOp>(
+                loc, boxTy, coor, emptyShape, emptySlice, update.typeparams());
+            auto rhs = rewriter.create<fir::EmboxOp>(
+                loc, boxTy, input, emptyShape, emptySlice, update.typeparams());
+            fir::factory::genRecordAssignment(builder, loc, fir::BoxValue(lhs),
+                                              fir::BoxValue(rhs));
+          } else {
+            fir::factory::genRecordAssignment(builder, loc, coor, input);
+          }
+        } else {
+          llvm::report_fatal_error("not a legal reference type");
+        }
+      } else {
+        rewriter.create<fir::StoreOp>(loc, input, coor);
+      }
+    };
+    auto lhsEltRefType = toRefType(update.merge().getType());
+    auto [_, lhsLoadResult] = materializeAssignment(
+        loc, rewriter, update, assignElement, lhsEltRefType);
+    update.replaceAllUsesWith(lhsLoadResult);
+    rewriter.replaceOp(update, lhsLoadResult);
+    return mlir::success();
+  }
+};
+
+class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
+public:
+  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
+                                 const ArrayCopyAnalysis &a,
+                                 const OperationUseMapT &m)
+      : ArrayUpdateConversionBase{ctx, a, m} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(ArrayModifyOp modify,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = modify.getLoc();
+    auto assignElement = [](mlir::Value) {
+      // Assignment already materialized by lowering using lhs element address.
+    };
+    auto lhsEltRefType = modify.getResult(0).getType();
+    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
+        loc, rewriter, modify, assignElement, lhsEltRefType);
+    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
+    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
+    return mlir::success();
+  }
+};
+
+class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
+public:
+  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
+                                const OperationUseMapT &m)
+      : OpRewritePattern{ctx}, useMap{m} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(ArrayFetchOp fetch,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto *op = fetch.getOperation();
+    rewriter.setInsertionPoint(op);
+    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
+    auto loc = fetch.getLoc();
+    mlir::Value coor =
+        genCoorOp(rewriter, loc, getEleTy(load.getType()),
+                  toRefType(fetch.getType()), load.memref(), load.shape(),
+                  load.slice(), fetch.indices(), load.typeparams(),
+                  fetch->hasAttr(fir::factory::attrFortranArrayOffsets()));
+    rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
+    return mlir::success();
+  }
+
+private:
+  const OperationUseMapT &useMap;
+};
+} // namespace
+
+namespace {
+class ArrayValueCopyConverter
+    : public ArrayValueCopyBase<ArrayValueCopyConverter> {
+public:
+  void runOnFunction() override {
+    auto func = getFunction();
+    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
+                            << func.getName() << "'\n");
+    auto *context = &getContext();
+
+    // Perform the conflict analysis.
+    auto &analysis = getAnalysis<ArrayCopyAnalysis>();
+    const auto &useMap = analysis.getUseMap();
+
+    // Phase 1 is performing a rewrite on the array accesses. Once all the
+    // array accesses are rewritten we can go on phase 2.
+    // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in
+    // /copy-out refers the Fortran copy-in/copy-out semantics on statements.
+    mlir::OwningRewritePatternList patterns1(context);
+    patterns1.insert<ArrayFetchConversion>(context, useMap);
+    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
+    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
+    mlir::ConversionTarget target(*context);
+    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
+                           mlir::arith::ArithmeticDialect,
+                           mlir::StandardOpsDialect>();
+    target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
+    // Rewrite the array fetch and array update ops.
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
+      mlir::emitError(mlir::UnknownLoc::get(context),
+                      "failure in array-value-copy pass, phase 1");
+      signalPassFailure();
+      // Phase 2 erases the loads/stores that phase 1 relies on; running it on
+      // a partially converted function is pointless.
+      return;
+    }
+
+    mlir::OwningRewritePatternList patterns2(context);
+    patterns2.insert<ArrayLoadConversion>(context);
+    patterns2.insert<ArrayMergeStoreConversion>(context);
+    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
+      mlir::emitError(mlir::UnknownLoc::get(context),
+                      "failure in array-value-copy pass, phase 2");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
+  return std::make_unique<ArrayValueCopyConverter>();
+}
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -3,16 +3,19 @@
   AffinePromotion.cpp
   AffineDemotion.cpp
   CharacterConversion.cpp
+  ArrayValueCopy.cpp
   Inliner.cpp
   ExternalNameConversion.cpp
   RewriteLoop.cpp
 
   DEPENDS
+  FIRBuilder
   FIRDialect
   FIRSupport
   FIROptTransformsPassIncGen
 
   LINK_LIBS
+  FIRBuilder
   FIRDialect
   MLIRAffineToStandard
   MLIRLLVMIR
diff --git
a/flang/test/Fir/array-value-copy.fir b/flang/test/Fir/array-value-copy.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/array-value-copy.fir
@@ -0,0 +1,544 @@
+// Test for the array-value-copy-pass
+// RUN: fir-opt --split-input-file --array-value-copy %s | FileCheck %s
+
+// Test simple fir.array_load/fir.array_fetch conversion to fir.array_coor
+func @array_fetch_conversion(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m: index, %n: index) {
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+  %s = fir.shape %m, %n : (index, index) -> !fir.shape<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+  %f = fir.array_fetch %av1, %c10, %c20 : (!fir.array<?x?xf32>, index, index) -> f32
+  return
+}
+
+// CHECK-LABEL: func @array_fetch_conversion(
+// CHECK-SAME:    %[[ARRAY:.*]]: !fir.ref<!fir.array<?x?xf32>>,
+// CHECK-SAME:    %[[ARG1:.*]]: index,
+// CHECK-SAME:    %[[ARG2:.*]]: index) {
+// CHECK:         %[[SHAPE:.*]] = fir.shape %[[ARG1]], %[[ARG2]] : (index, index) -> !fir.shape<2>
+// CHECK:         %{{.*}} = fir.undefined !fir.array<?x?xf32>
+// CHECK:         %[[VAL_0:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK:         %[[VAL_1:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK-NOT:     fir.array_load
+// CHECK-NOT:     fir.array_fetch
+// CHECK:         %[[COOR:.*]] = fir.array_coor %[[ARRAY]](%[[SHAPE]]) %[[VAL_0]], %[[VAL_1]] : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+// CHECK:         %{{.*}} = fir.load %[[COOR]] : !fir.ref<f32>
+
+// -----
+
+// Test simple fir.array_load/fir.array_update conversion without copy-in/copy-out
+func @array_update_conversion(%arr1 : !fir.box<!fir.array<?x?xf32>>, %m: index, %n: index) {
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+  %c1 = arith.constant 1 : index
+  %f = arith.constant 2.0 : f32
+  %s = fir.shape %m, %n : (index, index) -> !fir.shape<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.box<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+  %av2 = fir.array_update %av1, %f, %c1, %c1 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @array_update_conversion
+// CHECK-NOT:     fir.array_load
+// CHECK-NOT:     fir.array_update
+// CHECK:         %{{.*}} = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK:         %{{.*}} = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK:         %[[ARRAY_COOR:.*]] = fir.array_coor{{.*}}-> !fir.ref<f32>
+// CHECK:         fir.store %{{.*}} to %[[ARRAY_COOR]] : !fir.ref<f32>
+
+// -----
+
+// Test fir.array_load/fir.array_update conversion without copy-in/copy-out when the updates are in the branches of a fir.if
+func @array_update_conversion(%arr1 : !fir.box<!fir.array<?x?xf32>>, %m: index, %n: index, %cond: i1) {
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+  %c1 = arith.constant 1 : index
+  %f = arith.constant 2.0 : f32
+  %g = arith.constant 4.0 : f32
+  %s = fir.shape %m, %n : (index, index) -> !fir.shape<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.box<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+  fir.if %cond {
+    %av2 = fir.array_update %av1, %f, %c1, %c1 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+  } else {
+    %av2 = fir.array_update %av1, %g, %c1, %c1 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @array_update_conversion
+// CHECK-NOT:     fir.array_update
+// CHECK:         fir.if %{{.*}} {
+// CHECK:           %[[COOR0:.*]] = fir.array_coor{{.*}}-> !fir.ref<f32>
+// CHECK:           fir.store %{{.*}} to %[[COOR0]] : !fir.ref<f32>
+// CHECK:         } else {
+// CHECK:           %[[COOR1:.*]] = fir.array_coor{{.*}}-> !fir.ref<f32>
+// CHECK:           fir.store %{{.*}} to %[[COOR1]] : !fir.ref<f32>
+
+// -----
+
+// Test fir.array_load/fir.array_fetch/fir.array_update conversion with
+// an introduced copy-in/copy-out.
+//
+// This test corresponds to a simplified FIR version of the following Fortran
+// code.
+// ```
+//   integer :: i(10)
+//   i = i(10:1:-1)
+// end
+// ```
+
+func @conversion_with_temporary(%arr0 : !fir.ref<!fir.array<10xi32>>) {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = fir.array_load %arr0(%1) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.array<10xi32>
+  %c10_i64 = arith.constant 10 : i64
+  %3 = fir.convert %c10_i64 : (i64) -> index
+  %c1_i64 = arith.constant 1 : i64
+  %c-1_i64 = arith.constant -1 : i64
+  %4 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %5 = fir.slice %c10_i64, %c1_i64, %c-1_i64 : (i64, i64, i64) -> !fir.slice<1>
+  %6 = fir.array_load %arr0(%4) [%5] : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, !fir.slice<1>) -> !fir.array<10xi32>
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %7 = arith.subi %3, %c1 : index
+  %8 = fir.do_loop %arg0 = %c0 to %7 step %c1 unordered iter_args(%arg1 = %2) -> (!fir.array<10xi32>) {
+    %9 = fir.array_fetch %6, %arg0 : (!fir.array<10xi32>, index) -> i32
+    %10 = fir.array_update %arg1, %9, %arg0 : (!fir.array<10xi32>, i32, index) -> !fir.array<10xi32>
+    fir.result %10 : !fir.array<10xi32>
+  }
+  fir.array_merge_store %2, %8 to %arr0 : !fir.array<10xi32>, !fir.array<10xi32>, !fir.ref<!fir.array<10xi32>>
+  return
+}
+
+// CHECK-LABEL: func @conversion_with_temporary(
+// CHECK-SAME:    %[[ARR0:.*]]: !fir.ref<!fir.array<10xi32>>)
+// Allocation of temporary array.
+// CHECK:         %[[TEMP:.*]] = fir.allocmem !fir.array<10xi32>, %{{.*}}
+// Copy of original array to temp.
+// CHECK:         fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:           %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:           %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<i32>
+// CHECK:           %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:           fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<i32>
+// CHECK:         }
+// Perform the assignment i = i(10:1:-1) using the temporary array.
+// CHECK:         %{{.*}} = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered iter_args(%{{.*}} = %{{.*}}) -> (!fir.array<10xi32>) {
+// CHECK-NOT:       %{{.*}} = fir.array_fetch
+// CHECK-NOT:       %{{.*}} = fir.array_update
+// CHECK:           %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) [%{{.*}}] %{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, !fir.slice<1>, index) -> !fir.ref<i32>
+// CHECK:           %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<i32>
+// CHECK:           %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:           fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<i32>
+// CHECK:           fir.result %{{.*}} : !fir.array<10xi32>
+// CHECK:         }
+// Copy the result back to the original array.
+// CHECK:         fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:           %[[COOR0:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:           %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<i32>
+// CHECK:           %[[COOR1:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+// CHECK:           fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<i32>
+// CHECK:         }
+// Free temporary array.
+// CHECK:         fir.freemem %[[TEMP]] : !fir.heap<!fir.array<10xi32>>
+
+// -----
+
+// Test fir.array_load/fir.array_fetch/fir.array_update conversion with
+// an introduced copy-in/copy-out on a multidimensional array.
+
+func @conversion_with_temporary_multidim(%0: !fir.ref<!fir.array<10x5xi32>>) {
+  %c10 = arith.constant 10 : index
+  %c5 = arith.constant 5 : index
+  %1 = fir.shape %c10, %c5 : (index, index) -> !fir.shape<2>
+  %2 = fir.array_load %0(%1) : (!fir.ref<!fir.array<10x5xi32>>, !fir.shape<2>) -> !fir.array<10x5xi32>
+  %c10_i64 = arith.constant 10 : i64
+  %3 = fir.convert %c10_i64 : (i64) -> index
+  %c5_i64 = arith.constant 5 : i64
+  %4 = fir.convert %c5_i64 : (i64) -> index
+  %c1 = arith.constant 1 : index
+  %c10_i64_0 = arith.constant 10 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %c-1_i64 = arith.constant -1 : i64
+  %5 = arith.addi %c1, %c5 : index
+  %6 = arith.subi %5, %c1 : index
+  %c1_i64_1 = arith.constant 1 : i64
+  %7 = fir.shape %c10, %c5 : (index, index) -> !fir.shape<2>
+  %8 = fir.slice %c10_i64_0, %c1_i64, %c-1_i64, %c1, %6, %c1_i64_1 : (i64, i64, i64, index, index, i64) -> !fir.slice<2>
+  %9 = fir.array_load %0(%7) [%8] : (!fir.ref<!fir.array<10x5xi32>>, !fir.shape<2>, !fir.slice<2>) -> !fir.array<10x5xi32>
+  %c1_2 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %10 = arith.subi %3, %c1_2 : index
+  %11 = arith.subi %4, %c1_2 : index
+  %12 = fir.do_loop %arg0 = %c0 to %11 step %c1_2 unordered iter_args(%arg1 = %2) -> (!fir.array<10x5xi32>) {
+    %13 = fir.do_loop %arg2 = %c0 to %10 step %c1_2 unordered iter_args(%arg3 = %arg1) -> (!fir.array<10x5xi32>) {
+      %14 = fir.array_fetch %9, %arg2, %arg0 : (!fir.array<10x5xi32>, index, index) -> i32
+      %15 = fir.array_update %arg3, %14, %arg2, %arg0 : (!fir.array<10x5xi32>, i32, index, index) -> !fir.array<10x5xi32>
+      fir.result %15 : !fir.array<10x5xi32>
+    }
+    fir.result %13 : !fir.array<10x5xi32>
+  }
+  fir.array_merge_store %2, %12 to %0 : !fir.array<10x5xi32>, !fir.array<10x5xi32>, !fir.ref<!fir.array<10x5xi32>>
+  return
+}
+
+// CHECK-LABEL: func @conversion_with_temporary_multidim(
+// CHECK-SAME:    %[[ARR0:.*]]: !fir.ref<!fir.array<10x5xi32>>) {
+// CHECK:         %[[CST10:.*]] = arith.constant 10 : index
+// CHECK:         %[[CST5:.*]] = arith.constant 5 : index
+// CHECK:         %[[TEMP:.*]] = fir.allocmem !fir.array<10x5xi32>, %[[CST10]], %[[CST5]]
+// CHECK:         %[[IDX5:.*]] = fir.convert %[[CST5]] : (index) -> index
+// CHECK:         %[[UB5:.*]] = arith.subi %[[IDX5]], %{{.*}} : index
+// CHECK:         fir.do_loop %[[INDUC0:.*]] = %{{.*}} to %[[UB5]] step %{{.*}} {
+// CHECK:           %[[IDX10:.*]] = fir.convert %[[CST10]] : (index) -> index
+// CHECK:           %[[UB10:.*]] = arith.subi %[[IDX10]], %{{.*}} : index
+// CHECK:           fir.do_loop %[[INDUC1:.*]] = %{{.*}} to %[[UB10]] step %{{.*}} {
+// CHECK:             %[[IDX1:.*]] = arith.addi %[[INDUC1]], %{{.*}} : index
+// CHECK:             %[[IDX2:.*]] = arith.addi %[[INDUC0]], %{{.*}} : index
+// CHECK:             %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %[[IDX1]], %[[IDX2]] : (!fir.ref<!fir.array<10x5xi32>>, !fir.shape<2>, index, index) -> !fir.ref<i32>
+// CHECK:             %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<i32>
+// CHECK:             %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}}, %{{.*}} : (!fir.heap<!fir.array<10x5xi32>>, !fir.shape<2>, index, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<i32>
+// CHECK:         %{{.*}} = fir.do_loop %[[INDUC0:.*]] = %{{.*}} to %{{.*}} step %{{.*}} unordered iter_args(%{{.*}} = %{{.*}}) -> (!fir.array<10x5xi32>) {
+// CHECK:           %{{.*}} = fir.do_loop %[[INDUC1:.*]] = %{{.*}} to %{{.*}} step %{{.*}} unordered iter_args(%{{.*}} = %{{.*}}) -> (!fir.array<10x5xi32>) {
+// CHECK:             %[[IDX1:.*]] = arith.addi %[[INDUC1]], %{{.*}} : index
+// CHECK:             %[[IDX2:.*]] = arith.addi %[[INDUC0]], %{{.*}} : index
+// CHECK-NOT:         %{{.*}} = fir.array_fetch
+// CHECK-NOT:         %{{.*}} = fir.array_update
+// CHECK:             %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) [%{{.*}}] %[[IDX1]], %[[IDX2]] : (!fir.ref<!fir.array<10x5xi32>>, !fir.shape<2>, !fir.slice<2>, index, index) -> !fir.ref<i32>
+// CHECK:             %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<i32>
+// CHECK:             %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}}, %{{.*}} : (!fir.heap<!fir.array<10x5xi32>>, !fir.shape<2>, index, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<i32>
+// CHECK:         %[[IDX5:.*]] = fir.convert %[[CST5]] : (index) -> index
+// CHECK:         %[[UB5:.*]] = arith.subi %[[IDX5]], %{{.*}} : index
+// CHECK:         fir.do_loop %[[INDUC0:.*]] = %{{.*}} to %[[UB5]] step %{{.*}} {
+// CHECK:           %[[IDX10:.*]] = fir.convert %[[CST10]] : (index) -> index
+// CHECK:           %[[UB10:.*]] = arith.subi %[[IDX10]], %{{.*}} : index
+// CHECK:           fir.do_loop %[[INDUC1:.*]] = %{{.*}} to %[[UB10]] step %{{.*}} {
+// CHECK:             %[[IDX1:.*]] = arith.addi %[[INDUC1]], %{{.*}} : index
+// CHECK:             %[[IDX2:.*]] = arith.addi %[[INDUC0]], %{{.*}} : index
+// CHECK:             %[[COOR0:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %[[IDX1]], %[[IDX2]] : (!fir.heap<!fir.array<10x5xi32>>, !fir.shape<2>, index, index) -> !fir.ref<i32>
+// CHECK:             %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<i32>
+// CHECK:             %[[COOR1:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}}, %{{.*}} : (!fir.ref<!fir.array<10x5xi32>>, !fir.shape<2>, index, index) -> !fir.ref<i32>
+// CHECK:             fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<i32>
+// CHECK:         fir.freemem %[[TEMP]] : !fir.heap<!fir.array<10x5xi32>>
+
+// -----
+
+// Test fir.array_modify conversion with no overlap.
+func @array_modify_no_overlap(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.ref<!fir.array<100xf32>>) {
+  %c100 = arith.constant 100 : index
+  %c99 = arith.constant 99 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = fir.alloca f32
+  %1 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %2 = fir.array_load %arg0(%1) : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> !fir.array<100xf32>
+  %3 = fir.array_load %arg1(%1) : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> !fir.array<100xf32>
+  %4 = fir.do_loop %arg2 = %c0 to %c99 step %c1 unordered iter_args(%arg3 = %2) -> (!fir.array<100xf32>) {
+    %5 = fir.array_fetch %3, %arg2 : (!fir.array<100xf32>, index) -> f32
+    %6:2 = fir.array_modify %arg3, %arg2 : (!fir.array<100xf32>, index) -> (!fir.ref<f32>, !fir.array<100xf32>)
+    fir.store %5 to %0 : !fir.ref<f32>
+    fir.call @user_defined_assignment(%6#0, %0) : (!fir.ref<f32>, !fir.ref<f32>) -> ()
+    fir.result %6#1 : !fir.array<100xf32>
+  }
+  fir.array_merge_store %2, %4 to %arg0 : !fir.array<100xf32>, !fir.array<100xf32>, !fir.ref<!fir.array<100xf32>>
+  return
+}
+
+func private @user_defined_assignment(!fir.ref<f32>, !fir.ref<f32>)
+
+// CHECK-LABEL: func @array_modify_no_overlap(
+// CHECK-SAME:    %[[ARR0:.*]]: !fir.ref<!fir.array<100xf32>>,
+// CHECK-SAME:    %[[ARR1:.*]]: !fir.ref<!fir.array<100xf32>>) {
+// CHECK:         %[[VAR0:.*]] = fir.alloca f32
+// CHECK-COUNT-1: %{{.*}} = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered iter_args(%{{.*}} = %{{.*}}) -> (!fir.array<100xf32>) {
+// CHECK-NOT:     %{{.*}} = fir.array_fetch
+// CHECK-NOT:     %{{.*}} = fir.array_modify
+// CHECK:         %[[COOR0:.*]] = fir.array_coor %[[ARR1]](%{{.*}}) %{{.*}} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:         %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<f32>
+// CHECK:         %[[COOR1:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:         fir.store %[[LOAD0]] to %[[VAR0]] : !fir.ref<f32>
+// CHECK:         fir.call @{{.*}}(%[[COOR1]], %[[VAR0]]) : (!fir.ref<f32>, !fir.ref<f32>) -> ()
+
+// -----
+
+// Test fir.array_modify conversion with an overlap.
+// Test user_defined_assignment(arg0(:), arg0(100:1:-1))
+func @array_modify_overlap(%arg0: !fir.ref<!fir.array<100xf32>>) {
+  %c100 = arith.constant 100 : index
+  %c99 = arith.constant 99 : index
+  %c1 = arith.constant 1 : index
+  %c-1 = arith.constant -1 : index
+  %c0 = arith.constant 0 : index
+  %0 = fir.alloca f32
+  %1 = fir.shape %c100 : (index) -> !fir.shape<1>
+  %2 = fir.array_load %arg0(%1) : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> !fir.array<100xf32>
+  %3 = fir.slice %c100, %c1, %c-1 : (index, index, index) -> !fir.slice<1>
+  %4 = fir.array_load %arg0(%1) [%3] : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, !fir.slice<1>) -> !fir.array<100xf32>
+  %5 = fir.do_loop %arg1 = %c0 to %c99 step %c1 unordered iter_args(%arg2 = %2) -> (!fir.array<100xf32>) {
+    %6 = fir.array_fetch %4, %arg1 : (!fir.array<100xf32>, index) -> f32
+    %7:2 = fir.array_modify %arg2, %arg1 : (!fir.array<100xf32>, index) -> (!fir.ref<f32>, !fir.array<100xf32>)
+    fir.store %6 to %0 : !fir.ref<f32>
+    fir.call @user_defined_assignment(%7#0, %0) : (!fir.ref<f32>, !fir.ref<f32>) -> ()
+    fir.result %7#1 : !fir.array<100xf32>
+  }
+  fir.array_merge_store %2, %5 to %arg0 : !fir.array<100xf32>, !fir.array<100xf32>, !fir.ref<!fir.array<100xf32>>
+  return
+}
+
+func private @user_defined_assignment(!fir.ref<f32>, !fir.ref<f32>)
+
+// CHECK-LABEL: func @array_modify_overlap(
+// CHECK-SAME:    %[[ARR0:.*]]: !fir.ref<!fir.array<100xf32>>) {
+// CHECK:         %[[VAR0:.*]] = fir.alloca f32
+// Allocate the temporary array.
+// CHECK:         %[[TEMP:.*]] = fir.allocmem !fir.array<100xf32>, %{{.*}}
+// Copy original array to temp.
+// CHECK:         fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:           %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:           %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<f32>
+// CHECK:           %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:           fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<f32>
+// CHECK:         }
+// CHECK:         %{{.*}} = fir.undefined !fir.array<100xf32>
+// CHECK:         %{{.*}} = fir.undefined !fir.array<100xf32>
+// CHECK-NOT:     %{{.*}} = fir.array_fetch
+// CHECK-NOT:     %{{.*}} = fir.array_modify
+// CHECK:         %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) {{\[}}%{{.*}}] %{{.*}} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, !fir.slice<1>, index) -> !fir.ref<f32>
+// CHECK:         %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<f32>
+// CHECK:         %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:         fir.store %[[LOAD0]] to %[[VAR0]] : !fir.ref<f32>
+// CHECK:         fir.call @user_defined_assignment(%[[COOR1]], %[[VAR0]]) : (!fir.ref<f32>, !fir.ref<f32>) -> ()
+// CHECK:         }
+// Copy back result to original array from temp.
+// CHECK:         fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:           %[[COOR0:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:           %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref<f32>
+// CHECK:           %[[COOR1:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+// CHECK:           fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref<f32>
+// CHECK:         }
+// Free the temporary array.
+// CHECK: fir.freemem %[[TEMP]] : !fir.heap> +// CHECK: return +// CHECK: } + +// ----- + +// Test array of types with no overlap +func @array_of_types() { + %0 = fir.alloca i32 {bindc_name = "j", uniq_name = "_QEj"} + %1 = fir.address_of(@_QEtypes) : !fir.ref}>>> + %c1_i32 = arith.constant 1 : i32 + %2 = fir.convert %c1_i32 : (i32) -> index + %c10_i32 = arith.constant 10 : i32 + %3 = fir.convert %c10_i32 : (i32) -> index + %c1 = arith.constant 1 : index + %4 = fir.do_loop %arg0 = %2 to %3 step %c1 -> index { + %6 = fir.convert %arg0 : (index) -> i32 + fir.store %6 to %0 : !fir.ref + %c1_0 = arith.constant 1 : index + %7 = fir.load %0 : !fir.ref + %8 = fir.convert %7 : (i32) -> i64 + %c1_i64 = arith.constant 1 : i64 + %9 = arith.subi %8, %c1_i64 : i64 + %10 = fir.coordinate_of %1, %9 : (!fir.ref}>>>, i64) -> !fir.ref}>> + %11 = fir.field_index i, !fir.type<_QTd{i:!fir.array<10xi32>}> + %12 = fir.coordinate_of %10, %11 : (!fir.ref}>>, !fir.field) -> !fir.ref> + %c10 = arith.constant 10 : index + %13 = arith.addi %c1_0, %c10 : index + %14 = arith.subi %13, %c1_0 : index + %c1_i64_1 = arith.constant 1 : i64 + %15 = fir.shape %c10 : (index) -> !fir.shape<1> + %16 = fir.slice %c1_0, %14, %c1_i64_1 : (index, index, i64) -> !fir.slice<1> + %17 = fir.array_load %12(%15) [%16] : (!fir.ref>, !fir.shape<1>, !fir.slice<1>) -> !fir.array<10xi32> + %c10_i64 = arith.constant 10 : i64 + %18 = fir.convert %c10_i64 : (i64) -> index + %c0_i32 = arith.constant 0 : i32 + %c1_2 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %19 = arith.subi %18, %c1_2 : index + %20 = fir.do_loop %arg1 = %c0 to %19 step %c1_2 unordered iter_args(%arg2 = %17) -> (!fir.array<10xi32>) { + %22 = fir.array_update %arg2, %c0_i32, %arg1 : (!fir.array<10xi32>, i32, index) -> !fir.array<10xi32> + fir.result %22 : !fir.array<10xi32> + } + fir.array_merge_store %17, %20 to %12[%16] : !fir.array<10xi32>, !fir.array<10xi32>, !fir.ref>, !fir.slice<1> + %21 = arith.addi %arg0, %c1 : index + fir.result 
%21 : index + } + %5 = fir.convert %4 : (index) -> i32 + fir.store %5 to %0 : !fir.ref + return +} + +// CHECK-LABEL: func @array_of_types() { +// CHECK: %{{.*}} = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} -> index { +// CHECK: %{{.*}} = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered iter_args(%arg2 = %17) -> (!fir.array<10xi32>) { +// CHECK-NOT: %{{.*}} = fir.array_update +// CHECK: %[[COOR0:.*]] = fir.array_coor %{{.*}}(%{{.*}}) [%{{.*}}] %{{.*}} : (!fir.ref>, !fir.shape<1>, !fir.slice<1>, index) -> !fir.ref +// CHECK: fir.store %{{.*}} to %[[COOR0]] : !fir.ref +// CHECK-NOT: fir.array_merge_store + +// ----- + +// Test fir.array_load/boxed array +func @conversion_with_temporary_boxed_array(%arr0 : !fir.box>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.array_load %arr0(%1) : (!fir.box>, !fir.shape<1>) -> !fir.array<10xi32> + %c10_i64 = arith.constant 10 : i64 + %3 = fir.convert %c10_i64 : (i64) -> index + %c1_i64 = arith.constant 1 : i64 + %c-1_i64 = arith.constant -1 : i64 + %4 = fir.shape %c10 : (index) -> !fir.shape<1> + %5 = fir.slice %c10_i64, %c1_i64, %c-1_i64 : (i64, i64, i64) -> !fir.slice<1> + %6 = fir.array_load %arr0(%4) [%5] : (!fir.box>, !fir.shape<1>, !fir.slice<1>) -> !fir.array<10xi32> + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %7 = arith.subi %3, %c1 : index + %8 = fir.do_loop %arg0 = %c0 to %7 step %c1 unordered iter_args(%arg1 = %2) -> (!fir.array<10xi32>) { + %9 = fir.array_fetch %6, %arg0 : (!fir.array<10xi32>, index) -> i32 + %10 = fir.array_update %arg1, %9, %arg0 : (!fir.array<10xi32>, i32, index) -> !fir.array<10xi32> + fir.result %10 : !fir.array<10xi32> + } + fir.array_merge_store %2, %8 to %arr0 : !fir.array<10xi32>, !fir.array<10xi32>, !fir.box> + return +} + +// CHECK-LABEL: func @conversion_with_temporary_boxed_array( +// CHECK-SAME: %[[ARR0:.*]]: !fir.box>) +// Allocation of temporary array. 
+// CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<10xi32>, %{{.*}} +// Copy of original array to temp. +// CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.box>, !fir.shape<1>, index) -> !fir.ref +// CHECK: %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref +// CHECK: %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref +// CHECK: } +// Perform the assignment i = i(10:1:-1) using the temporary array. +// CHECK: %{{.*}} = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered iter_args(%{{.*}} = %{{.*}}) -> (!fir.array<10xi32>) { +// CHECK-NOT: %{{.*}} = fir.array_fetch +// CHECK-NOT: %{{.*}} = fir.update +// CHECK: %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) [%{{.*}}] %{{.*}} : (!fir.box>, !fir.shape<1>, !fir.slice<1>, index) -> !fir.ref +// CHECK: %[[LOAD0:.*]] = fir.load %[[COOR0]] : !fir.ref +// CHECK: %[[COOR1:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref +// CHECK: fir.result %{{.*}} : !fir.array<10xi32> +// CHECK: } +// Copy the result back to the original array. +// CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK: %[[COOR0:.*]] = fir.array_coor %[[TEMP]](%{{.*}}) %{{.*}} : (!fir.heap>, !fir.shape<1>, index) -> !fir.ref +// CHECK: %[[LOAD0:.*]] = fir.load %[[COOR0:.*]] : !fir.ref +// CHECK: %[[COOR1:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.box>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[LOAD0]] to %[[COOR1]] : !fir.ref +// CHECK: } +// Free temporary array. +// CHECK: fir.freemem %[[TEMP]] : !fir.heap> + +// ----- + +// Test simple fir.array_update with Fortran.offsets attribute. 
+func @array_update_conversion(%arr1 : !fir.box>, %m: index, %n: index) { + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + %c1 = arith.constant 1 : index + %f = arith.constant 2.0 : f32 + %s = fir.shape %m, %n : (index, index) -> !fir.shape<2> + %av1 = fir.array_load %arr1(%s) : (!fir.box>, !fir.shape<2>) -> !fir.array + %av2 = fir.array_update %av1, %f, %c1, %c1 {Fortran.offsets} : (!fir.array, f32, index, index) -> !fir.array + return +} + +// CHECK-LABEL: func @array_update_conversion +// CHECK-NOT: fir.array_update +// CHECK-NOT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : index +// CHECK: %[[ARRAY_COOR:.*]] = fir.array_coor{{.*}}-> !fir.ref +// CHECK: fir.store %{{.*}} to %[[ARRAY_COOR]] : !fir.ref + +// ----- + +// Test fir.array_fetch on derived type members in an array of derived types. +func @array_fetch_derived_type(%0 : !fir.ref}>>>) { + %1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QEi"} + %c1_i32 = arith.constant 1 : i32 + %2 = fir.convert %c1_i32 : (i32) -> index + %c10_i32 = arith.constant 10 : i32 + %3 = fir.convert %c10_i32 : (i32) -> index + %c1 = arith.constant 1 : index + %shape = fir.shape %2 : (index) -> !fir.shape<1> + %arr0 = fir.array_load %0(%shape) : (!fir.ref}>>>, !fir.shape<1>) -> !fir.array<10x!fir.type<_QTu{mt:!fir.type<_QTt{mem:i32}>}>> + %4 = fir.do_loop %arg0 = %2 to %3 step %c1 -> index { + %6 = fir.convert %arg0 : (index) -> i32 + fir.store %6 to %1 : !fir.ref + %c1_i32_0 = arith.constant 1 : i32 + %7 = fir.load %1 : !fir.ref + %8 = fir.convert %7 : (i32) -> i64 + %c1_i64 = arith.constant 1 : i64 + %9 = arith.subi %8, %c1_i64 : i64 + %11 = fir.field_index mt, !fir.type<_QTu{mt:!fir.type<_QTt{mem:i32}>}> + %12 = fir.field_index mem, !fir.type<_QTt{mem:i32}> + %idx = fir.convert %9 : (i64) -> index + %res = fir.array_fetch %arr0, %idx, %11, %12 : (!fir.array<10x!fir.type<_QTu{mt:!fir.type<_QTt{mem:i32}>}>>, index, !fir.field, !fir.field) -> i32 + %14 = arith.addi %arg0, %c1 : index + fir.result %14 : index 
+ } + %5 = fir.convert %4 : (index) -> i32 + fir.store %5 to %1 : !fir.ref + return +} + +// CHECK-LABEL: func @array_fetch_derived_type( +// CHECK-SAME: %[[ARR0:.*]]: !fir.ref}>>>) { +// CHECK: %{{.*}} = fir.do_loop +// CHECK: %[[FIELD_MT:.*]] = fir.field_index mt, !fir.type<_QTu{mt:!fir.type<_QTt{mem:i32}>}> +// CHECK: %[[FIELD_MEM:.*]] = fir.field_index mem, !fir.type<_QTt{mem:i32}> +// CHECK-NOT: %{{.*}} = fir.array_fetch +// CHECK: %[[COOR0:.*]] = fir.array_coor %[[ARR0]](%{{.*}}) %{{.*}} : (!fir.ref}>>>, !fir.shape<1>, index) -> !fir.ref}>> +// CHECK: %[[COOR_OF:.*]] = fir.coordinate_of %[[COOR0]], %[[FIELD_MT]], %[[FIELD_MEM]] : (!fir.ref}>>, !fir.field, !fir.field) -> !fir.ref +// CHECK: %{{.*}} = fir.load %[[COOR_OF]] : !fir.ref + +// ----- + +// Test simple fir.array_load/fir.array_update conversion without copy-in/copy-out with a `fir.box` +func @array_update_conversion(%arr1 : !fir.box>, %m: index, %n: index) { + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + %c1 = arith.constant 1 : index + %f = arith.constant 2.0 : f32 + %s = fir.shape %m, %n : (index, index) -> !fir.shape<2> + %av1 = fir.array_load %arr1(%s) : (!fir.box>, !fir.shape<2>) -> !fir.array + %av2 = fir.array_update %av1, %f, %c1, %c1 : (!fir.array, f32, index, index) -> !fir.array + return +} + +// ----- + +// Test array operation with conditional update. 
+ +func @array_operation_with_cond_update(%arg0: !fir.ref>, %cond1: i1) { + %c100 = arith.constant 100 : index + %c1 = arith.constant 1 : index + %c-1 = arith.constant -1 : index + %f = arith.constant 2.0 : f32 + %1 = fir.shape %c100 : (index) -> !fir.shape<1> + %2 = fir.array_load %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.array<100xf32> + %arg2 = fir.if %cond1 -> !fir.array<100xf32> { + fir.result %2 : !fir.array<100xf32> + } else { + %r = fir.array_update %2, %f, %c1 : (!fir.array<100xf32>, f32, index) -> !fir.array<100xf32> + fir.result %r : !fir.array<100xf32> + } + fir.array_merge_store %2, %arg2 to %arg0 : !fir.array<100xf32>, !fir.array<100xf32>, !fir.ref> + return +} + +// CHECK-LABEL: func @array_operation_with_cond_update( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>, %[[COND:.*]]: i1) { +// CHECK: %[[ARRAY_LOAD:.*]] = fir.undefined !fir.array<100xf32> +// CHECK: %[[IF_RES:.*]] = fir.if %[[COND]] -> (!fir.array<100xf32>) { +// CHECK: fir.result %[[ARRAY_LOAD]] : !fir.array<100xf32> +// CHECK: } else { +// CHECK: %[[UPDATE0:.*]] = fir.array_coor %[[ARG0]](%{{.*}}) %{{.*}} : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %{{.*}} to %{{.*}} : !fir.ref +// CHECK: fir.result %[[ARRAY_LOAD]] : !fir.array<100xf32> +// CHECK: } diff --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp --- a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp +++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp @@ -144,8 +144,8 @@ auto loc = builder.getUnknownLoc(); auto realTy = mlir::FloatType::getF64(ctx); auto cst = builder.createRealZeroConstant(loc, realTy); - EXPECT_TRUE(mlir::isa(cst.getDefiningOp())); - auto cstOp = dyn_cast(cst.getDefiningOp()); + EXPECT_TRUE(mlir::isa(cst.getDefiningOp())); + auto cstOp = dyn_cast(cst.getDefiningOp()); EXPECT_EQ(realTy, cstOp.getType()); EXPECT_EQ(0u, cstOp.value().cast().getValue().convertToDouble()); }