diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -21,6 +21,8 @@ #include "flang/Optimizer/Support/KindMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" namespace fir { class AbstractArrayBox; @@ -57,6 +59,13 @@ /// Get a reference to the kind map. const fir::KindMapping &getKindMap() { return kindMap; } + /// The LHS and RHS are not always in agreement in terms of + /// type. In some cases, the disagreement is between COMPLEX and other scalar + /// types. In that case, the conversion must insert/extract out of a COMPLEX + /// value to have the proper semantics and be strongly typed. + mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy, + mlir::Value val); + /// Get the entry block of the current Function mlir::Block *getEntryBlock() { return &getFunction().front(); } @@ -189,11 +198,6 @@ mlir::StringAttr createWeakLinkage() { return getStringAttr("weak"); } - /// Cast the input value to IndexType. - mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) { - return createConvert(loc, getIndexType(), val); - } - /// Get a function by name. If the function exists in the current module, it /// is returned. Otherwise, a null FuncOp is returned. mlir::FuncOp getNamedFunction(llvm::StringRef name) { @@ -243,6 +247,11 @@ return createFunction(loc, module, name, ty); } + /// Cast the input value to IndexType. + mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) { + return createConvert(loc, getIndexType(), val); + } + /// Construct one of the two forms of shape op from an array box. 
mlir::Value genShape(mlir::Location loc, const fir::AbstractArrayBox &arr); mlir::Value genShape(mlir::Location loc, llvm::ArrayRef shift, @@ -253,6 +262,18 @@ /// this may create a `fir.shift` op. mlir::Value createShape(mlir::Location loc, const fir::ExtendedValue &exv); + /// Create a slice op extended value. The value to be sliced, `exv`, must be + /// an array. + mlir::Value createSlice(mlir::Location loc, const fir::ExtendedValue &exv, + mlir::ValueRange triples, mlir::ValueRange path); + + /// Create a boxed value (Fortran descriptor) to be passed to the runtime. + /// \p exv is an extended value holding a memory reference to the object that + /// must be boxed. This function will crash if provided something that is not + /// a memory reference type. + /// Array entities are boxed with a shape and character with their length. + mlir::Value createBox(mlir::Location loc, const fir::ExtendedValue &exv); + /// Create constant i1 with value 1. if \p b is true or 0. otherwise mlir::Value createBool(mlir::Location loc, bool b) { return createIntegerConstant(loc, getIntegerType(1), b ? 1 : 0); @@ -326,6 +347,12 @@ /// Generate code testing \p addr is a null address. mlir::Value genIsNull(mlir::Location loc, mlir::Value addr); + /// Compute the extent of (lb:ub:step) as max((ub-lb+step)/step, 0). See + /// Fortran 2018 9.5.3.3.2 section for more details. + mlir::Value genExtentFromTriplet(mlir::Location loc, mlir::Value lb, + mlir::Value ub, mlir::Value step, + mlir::Type type); + private: const KindMapping &kindMap; }; @@ -345,6 +372,17 @@ mlir::Value readCharLen(fir::FirOpBuilder &builder, mlir::Location loc, const fir::ExtendedValue &box); +/// Read or get the extent in dimension \p dim of the array described by \p box. +mlir::Value readExtent(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &box, unsigned dim); + +/// Read or get the lower bound in dimension \p dim of the array described by +/// \p box. 
If the lower bound is left default in the ExtendedValue, +/// \p defaultValue will be returned. +mlir::Value readLowerBound(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &box, unsigned dim, + mlir::Value defaultValue); + /// Read extents from \p box. llvm::SmallVector readExtents(fir::FirOpBuilder &builder, mlir::Location loc, @@ -356,6 +394,14 @@ mlir::Location loc, const fir::ExtendedValue &box); +/// Read a fir::BoxValue into an fir::UnboxValue, a fir::ArrayBoxValue or a +/// fir::CharArrayBoxValue. This should only be called if the fir::BoxValue is +/// known to be contiguous given the context (or if the resulting address will +/// not be used). If the value is polymorphic, its dynamic type will be lost. +/// This must not be used on unlimited polymorphic and assumed rank entities. +fir::ExtendedValue readBoxValue(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::BoxValue &box); + //===----------------------------------------------------------------------===// // String literal helper helpers //===----------------------------------------------------------------------===// @@ -365,10 +411,21 @@ fir::ExtendedValue createStringLiteral(fir::FirOpBuilder &, mlir::Location, llvm::StringRef string); +/// Create a !fir.char<1> string literal global and returns a +/// fir::CharBoxValue with its address en length. +fir::ExtendedValue createStringLiteral(fir::FirOpBuilder &, mlir::Location, + llvm::StringRef string); + /// Unique a compiler generated identifier. A short prefix should be provided /// to hint at the origin of the identifier. std::string uniqueCGIdent(llvm::StringRef prefix, llvm::StringRef name); +/// Lowers the extents from the sequence type to Values. +/// Any unknown extents are lowered to undefined values. 
+llvm::SmallVector createExtents(fir::FirOpBuilder &builder, + mlir::Location loc, + fir::SequenceType seqTy); + //===----------------------------------------------------------------------===// // Location helpers //===----------------------------------------------------------------------===// @@ -379,6 +436,45 @@ /// Generate a constant of the given type with the location line number mlir::Value locationToLineNo(fir::FirOpBuilder &, mlir::Location, mlir::Type); +//===----------------------------------------------------------------------===// +// ExtendedValue helpers +//===----------------------------------------------------------------------===// + +/// Return the extended value for a component of a derived type instance given +/// the address of the component, \p component. This helper takes no extended +/// value for the enclosing derived type instance. +fir::ExtendedValue componentToExtendedValue(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value component); + +/// Given the address of an array element and the ExtendedValue describing the +/// array, returns the ExtendedValue describing the array element. The purpose +/// is to propagate the length parameters of the array to the element. +/// This can be used for elements of `array` or `array(i:j:k)`. If \p element +/// belongs to an array section `array%x` whose base is \p array, +/// arraySectionElementToExtendedValue must be used instead. +fir::ExtendedValue arrayElementToExtendedValue(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &array, + mlir::Value element); + +/// Build the ExtendedValue for \p element that is an element of an array or +/// array section with \p array base (`array` or `array(i:j:k)%x%y`). +/// If it is an array section, \p slice must be provided and be a fir::SliceOp +/// that describes the section. 
+fir::ExtendedValue arraySectionElementToExtendedValue( + fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &array, mlir::Value element, mlir::Value slice); + +/// Assign \p rhs to \p lhs. Both \p rhs and \p lhs must be scalar derived +/// types. The assignment follows Fortran intrinsic assignment semantic for +/// derived types (10.2.1.3 point 13). +void genRecordAssignment(fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &lhs, + const fir::ExtendedValue &rhs); + +mlir::TupleType getRaggedArrayHeaderType(fir::FirOpBuilder &builder); + } // namespace fir::factory #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Assign.h b/flang/include/flang/Optimizer/Builder/Runtime/Assign.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Builder/Runtime/Assign.h @@ -0,0 +1,32 @@ +//===-- Assign.h - generate assignment runtime API calls ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_ASSIGN_H +#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_ASSIGN_H + +namespace mlir { +class Value; +class Location; +} // namespace mlir + +namespace fir { +class FirOpBuilder; +} // namespace fir + +namespace fir::runtime { + +/// Generate runtime call to assign \p sourceBox to \p destBox. +/// \p destBox must be a fir.ref> and \p sourceBox a fir.box. +/// \p destBox Fortran descriptor may be modified if destBox is an allocatable +/// according to Fortran allocatable assignment rules, otherwise it is not +/// modified. 
+void genAssign(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value destBox, mlir::Value sourceBox); + +} // namespace fir::runtime +#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_ASSIGN_H diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h @@ -0,0 +1,407 @@ +//===-- RTBuilder.h ---------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines some C++17 template classes that are used to convert the +/// signatures of plain old C functions into a model that can be used to +/// generate MLIR calls to those functions. This can be used to autogenerate +/// tables at compiler compile-time to call runtime support code. +/// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RTBUILDER_H +#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RTBUILDER_H + +#include "flang/Common/Fortran.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/SmallVector.h" +#include + +// Incomplete type indicating C99 complex ABI in interfaces. Beware, _Complex +// and std::complex are layout compatible, but not compatible in all ABI call +// interface (e.g. X86 32 bits). _Complex is not standard C++, so do not use +// it here. 
+struct c_float_complex_t; +struct c_double_complex_t; + +namespace Fortran::runtime { +class Descriptor; +} // namespace Fortran::runtime + +namespace fir::runtime { + +using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *); +using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *); + +//===----------------------------------------------------------------------===// +// Type builder models +//===----------------------------------------------------------------------===// + +// TODO: all usages of sizeof in this file assume build == host == target. +// This will need to be re-visited for cross compilation. + +/// Return a function that returns the type signature model for the type `T` +/// when provided an MLIRContext*. This allows one to translate C(++) function +/// signatures from runtime header files to MLIR signatures into a static table +/// at compile-time. +/// +/// For example, when `T` is `int`, return a function that returns the MLIR +/// standard type `i32` when `sizeof(int)` is 4. 
+template +static constexpr TypeBuilderFunc getModel(); +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(short int)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(int)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get(mlir::IntegerType::get(context, 8)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get(mlir::IntegerType::get(context, 16)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get(mlir::IntegerType::get(context, 32)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(signed char)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::LLVMPointerType::get(mlir::IntegerType::get(context, 8)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get( + fir::LLVMPointerType::get(mlir::IntegerType::get(context, 8))); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * 
sizeof(long)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(std::size_t)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(unsigned long)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 8 * sizeof(unsigned long long)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::FloatType::getF64(context); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::FloatType::getF32(context); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} 
+template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, 1); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + TypeBuilderFunc f{getModel()}; + return fir::ReferenceType::get(f(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel &>() { + return [](mlir::MLIRContext *context) -> mlir::Type { + auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context)); + return fir::ReferenceType::get(ty); + }; +} +template <> +constexpr TypeBuilderFunc getModel &>() { + return [](mlir::MLIRContext *context) -> mlir::Type { + auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context)); + return fir::ReferenceType::get(ty); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ComplexType::get(context, sizeof(float)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ComplexType::get(context, sizeof(double)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::BoxType::get(mlir::NoneType::get(context)); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return fir::ReferenceType::get( + fir::BoxType::get(mlir::NoneType::get(context))); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return getModel(); +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, + sizeof(Fortran::common::TypeCategory) * 8); + }; +} +template <> +constexpr TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> 
mlir::Type { + return mlir::NoneType::get(context); + }; +} + +template +struct RuntimeTableKey; +template +struct RuntimeTableKey { + static constexpr FuncTypeBuilderFunc getTypeModel() { + return [](mlir::MLIRContext *ctxt) { + TypeBuilderFunc ret = getModel(); + std::array args = {getModel()...}; + mlir::Type retTy = ret(ctxt); + llvm::SmallVector argTys; + for (auto f : args) + argTys.push_back(f(ctxt)); + return mlir::FunctionType::get(ctxt, argTys, {retTy}); + }; + } +}; + +//===----------------------------------------------------------------------===// +// Runtime table building (constexpr folded) +//===----------------------------------------------------------------------===// + +template +using RuntimeIdentifier = std::integer_sequence; + +namespace details { +template +static constexpr std::integer_sequence +concat(std::integer_sequence, std::integer_sequence) { + return {}; +} +template +static constexpr auto concat(std::integer_sequence, + std::integer_sequence, Cs...) { + return concat(std::integer_sequence{}, Cs{}...); +} +template +static constexpr std::integer_sequence concat(std::integer_sequence) { + return {}; +} +template +static constexpr auto filterZero(std::integer_sequence) { + if constexpr (a != 0) { + return std::integer_sequence{}; + } else { + return std::integer_sequence{}; + } +} +template +static constexpr auto filter(std::integer_sequence) { + if constexpr (sizeof...(b) > 0) { + return details::concat(filterZero(std::integer_sequence{})...); + } else { + return std::integer_sequence{}; + } +} +} // namespace details + +template +struct RuntimeTableEntry; +template +struct RuntimeTableEntry, RuntimeIdentifier> { + static constexpr FuncTypeBuilderFunc getTypeModel() { + return RuntimeTableKey::getTypeModel(); + } + static constexpr const char name[sizeof...(Cs) + 1] = {Cs..., '\0'}; +}; + +#undef E +#define E(L, I) (I < sizeof(L) / sizeof(*L) ? 
L[I] : 0) +#define QuoteKey(X) #X +#define ExpandAndQuoteKey(X) QuoteKey(X) +#define MacroExpandKey(X) \ + E(X, 0), E(X, 1), E(X, 2), E(X, 3), E(X, 4), E(X, 5), E(X, 6), E(X, 7), \ + E(X, 8), E(X, 9), E(X, 10), E(X, 11), E(X, 12), E(X, 13), E(X, 14), \ + E(X, 15), E(X, 16), E(X, 17), E(X, 18), E(X, 19), E(X, 20), E(X, 21), \ + E(X, 22), E(X, 23), E(X, 24), E(X, 25), E(X, 26), E(X, 27), E(X, 28), \ + E(X, 29), E(X, 30), E(X, 31), E(X, 32), E(X, 33), E(X, 34), E(X, 35), \ + E(X, 36), E(X, 37), E(X, 38), E(X, 39), E(X, 40), E(X, 41), E(X, 42), \ + E(X, 43), E(X, 44), E(X, 45), E(X, 46), E(X, 47), E(X, 48), E(X, 49) +#define ExpandKey(X) MacroExpandKey(QuoteKey(X)) +#define FullSeq(X) std::integer_sequence +#define AsSequence(X) decltype(fir::runtime::details::filter(FullSeq(X){})) +#define mkKey(X) \ + fir::runtime::RuntimeTableEntry, \ + AsSequence(X)> +#define mkRTKey(X) mkKey(RTNAME(X)) + +/// Get (or generate) the MLIR FuncOp for a given runtime function. Its template +/// argument is intended to be of the form: mkRTKey(runtime function name). +/// Clients should add "using namespace Fortran::runtime" +/// in order to use this function. +template +static mlir::FuncOp getRuntimeFunc(mlir::Location loc, + fir::FirOpBuilder &builder) { + auto name = RuntimeEntry::name; + auto func = builder.getNamedFunction(name); + if (func) + return func; + auto funTy = RuntimeEntry::getTypeModel()(builder.getContext()); + func = builder.createFunction(loc, name, funTy); + func->setAttr("fir.runtime", builder.getUnitAttr()); + return func; +} + +namespace helper { +template +void createArguments(llvm::SmallVectorImpl &result, + fir::FirOpBuilder &builder, mlir::Location loc, + mlir::FunctionType fTy, A arg) { + result.emplace_back(builder.createConvert(loc, fTy.getInput(N), arg)); +} + +template +void createArguments(llvm::SmallVectorImpl &result, + fir::FirOpBuilder &builder, mlir::Location loc, + mlir::FunctionType fTy, A arg, As... 
args) { + result.emplace_back(builder.createConvert(loc, fTy.getInput(N), arg)); + createArguments(result, builder, loc, fTy, args...); +} +} // namespace helper + +/// Create a SmallVector of arguments for a runtime call. +template +llvm::SmallVector +createArguments(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::FunctionType fTy, As... args) { + llvm::SmallVector result; + helper::createArguments<0>(result, builder, loc, fTy, args...); + return result; +} + +} // namespace fir::runtime + +#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_RTBUILDER_H diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -61,6 +61,7 @@ /// Attribute to mark Fortran entities with the CONTIGUOUS attribute. constexpr llvm::StringRef getContiguousAttrName() { return "fir.contiguous"; } + /// Attribute to mark Fortran entities with the OPTIONAL attribute. constexpr llvm::StringRef getOptionalAttrName() { return "fir.optional"; } diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -224,6 +224,28 @@ }]; } +def fir_LLVMPointerType : FIR_Type<"LLVMPointer", "llvm_ptr"> { + let summary = "Like LLVM pointer type"; + + let description = [{ + A pointer type that does not have any of the constraints and semantics + of other FIR pointer types and that translates to llvm pointer types. + It is meant to implement indirection that cannot be expressed directly + in Fortran, but are needed to implement some Fortran features (e.g, + double indirections). 
+ }]; + + let parameters = (ins "mlir::Type":$eleTy); + + let skipDefaultBuilders = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ + return Base::get(elementType.getContext(), elementType); + }]>, + ]; +} + def fir_PointerType : FIR_Type<"Pointer", "ptr"> { let summary = "Reference to a POINTER attribute type"; @@ -401,6 +423,12 @@ "mlir::Type":$eleTy), [{ return get(eleTy.getContext(), shape, eleTy, {}); }]>, + TypeBuilderWithInferredContext<(ins + "mlir::Type":$eleTy, + "size_t":$dimensions), [{ + llvm::SmallVector shape(dimensions, getUnknownExtent()); + return get(eleTy.getContext(), shape, eleTy, {}); + }]> ]; let extraClassDeclaration = [{ diff --git a/flang/include/flang/Optimizer/Transforms/Factory.h b/flang/include/flang/Optimizer/Transforms/Factory.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Optimizer/Transforms/Factory.h @@ -0,0 +1,255 @@ +//===-- Optimizer/Transforms/Factory.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Templates to generate more complex code patterns in transformation passes. +// In transformation passes, front-end information such as is available in +// lowering is not available. 
+// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_FACTORY_H +#define FORTRAN_OPTIMIZER_TRANSFORMS_FACTORY_H + +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "llvm/ADT/iterator_range.h" + +namespace mlir { +class Location; +class Value; +} // namespace mlir + +namespace fir::factory { + +constexpr llvm::StringRef attrFortranArrayOffsets() { + return "Fortran.offsets"; +} + +/// Generate a character copy with optimized forms. +/// +/// If the lengths are constant and equal, use load/store rather than a loop. +/// Otherwise, if the lengths are constant and the input is longer than the +/// output, generate a loop to move a truncated portion of the source to the +/// destination. Finally, if the lengths are runtime values or the destination +/// is longer than the source, move the entire source character and pad the +/// destination with spaces as needed. 
+template +void genCharacterCopy(mlir::Value src, mlir::Value srcLen, mlir::Value dst, + mlir::Value dstLen, B &builder, mlir::Location loc) { + auto srcTy = + fir::dyn_cast_ptrEleTy(src.getType()).template cast(); + auto dstTy = + fir::dyn_cast_ptrEleTy(dst.getType()).template cast(); + if (!srcLen && !dstLen && srcTy.getFKind() == dstTy.getFKind() && + srcTy.getLen() == dstTy.getLen()) { + // same size, so just use load and store + auto load = builder.template create(loc, src); + builder.template create(loc, load, dst); + return; + } + auto zero = builder.template create(loc, 0); + auto one = builder.template create(loc, 1); + auto toArrayTy = [&](fir::CharacterType ty) { + return fir::ReferenceType::get(fir::SequenceType::get( + fir::SequenceType::ShapeRef{fir::SequenceType::getUnknownExtent()}, + fir::CharacterType::getSingleton(ty.getContext(), ty.getFKind()))); + }; + auto toEleTy = [&](fir::ReferenceType ty) { + auto seqTy = ty.getEleTy().cast(); + return seqTy.getEleTy().cast(); + }; + auto toCoorTy = [&](fir::ReferenceType ty) { + return fir::ReferenceType::get(toEleTy(ty)); + }; + if (!srcLen && !dstLen && srcTy.getLen() >= dstTy.getLen()) { + auto upper = builder.template create( + loc, dstTy.getLen() - 1); + auto loop = builder.template create(loc, zero, upper, one); + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(loop.getBody()); + auto csrcTy = toArrayTy(srcTy); + auto csrc = builder.template create(loc, csrcTy, src); + auto in = builder.template create( + loc, toCoorTy(csrcTy), csrc, loop.getInductionVar()); + auto load = builder.template create(loc, in); + auto cdstTy = toArrayTy(dstTy); + auto cdst = builder.template create(loc, cdstTy, dst); + auto out = builder.template create( + loc, toCoorTy(cdstTy), cdst, loop.getInductionVar()); + mlir::Value cast = + srcTy.getFKind() == dstTy.getFKind() + ? 
load.getResult() + : builder + .template create(loc, toEleTy(cdstTy), load) + .getResult(); + builder.template create(loc, cast, out); + builder.restoreInsertionPoint(insPt); + return; + } + auto minusOne = [&](mlir::Value v) -> mlir::Value { + return builder.template create( + loc, builder.template create(loc, one.getType(), v), + one); + }; + mlir::Value len = dstLen ? minusOne(dstLen) + : builder + .template create( + loc, dstTy.getLen() - 1) + .getResult(); + auto loop = builder.template create(loc, zero, len, one); + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(loop.getBody()); + mlir::Value slen = + srcLen + ? builder.template create(loc, one.getType(), srcLen) + .getResult() + : builder.template create(loc, srcTy.getLen()) + .getResult(); + auto cond = builder.template create( + loc, arith::CmpIPredicate::slt, loop.getInductionVar(), slen); + auto ifOp = builder.template create(loc, cond, /*withElse=*/true); + builder.setInsertionPointToStart(&ifOp.thenRegion().front()); + auto csrcTy = toArrayTy(srcTy); + auto csrc = builder.template create(loc, csrcTy, src); + auto in = builder.template create( + loc, toCoorTy(csrcTy), csrc, loop.getInductionVar()); + auto load = builder.template create(loc, in); + auto cdstTy = toArrayTy(dstTy); + auto cdst = builder.template create(loc, cdstTy, dst); + auto out = builder.template create( + loc, toCoorTy(cdstTy), cdst, loop.getInductionVar()); + mlir::Value cast = + srcTy.getFKind() == dstTy.getFKind() + ? 
load.getResult() + : builder.template create(loc, toEleTy(cdstTy), load) + .getResult(); + builder.template create(loc, cast, out); + builder.setInsertionPointToStart(&ifOp.elseRegion().front()); + auto space = builder.template create( + loc, toEleTy(cdstTy), llvm::ArrayRef{' '}); + auto cdst2 = builder.template create(loc, cdstTy, dst); + auto out2 = builder.template create( + loc, toCoorTy(cdstTy), cdst2, loop.getInductionVar()); + builder.template create(loc, space, out2); + builder.restoreInsertionPoint(insPt); +} + +/// Get extents from fir.shape/fir.shape_shift op. Empty result if +/// \p shapeVal is empty or is a fir.shift. +inline std::vector getExtents(mlir::Value shapeVal) { + if (shapeVal) + if (auto *shapeOp = shapeVal.getDefiningOp()) { + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getExtents(); + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getExtents(); + } + return {}; +} + +/// Get origins from fir.shape_shift/fir.shift op. Empty result if +/// \p shapeVal is empty or is a fir.shape. +inline std::vector getOrigins(mlir::Value shapeVal) { + if (shapeVal) + if (auto *shapeOp = shapeVal.getDefiningOp()) { + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getOrigins(); + if (auto shOp = mlir::dyn_cast(shapeOp)) + return shOp.getOrigins(); + } + return {}; +} + +/// Convert the normalized indices on array_fetch and array_update to the +/// dynamic (and non-zero) origin required by array_coor. +/// Do not adjust any trailing components in the path as they specify a +/// particular path into the array value and must already correspond to the +/// structure of an element. 
+template +llvm::SmallVector +originateIndices(mlir::Location loc, B &builder, mlir::Type memTy, + mlir::Value shapeVal, mlir::ValueRange indices) { + llvm::SmallVector result; + auto origins = getOrigins(shapeVal); + if (origins.empty()) { + assert(!shapeVal || mlir::isa(shapeVal.getDefiningOp())); + auto ty = fir::dyn_cast_ptrOrBoxEleTy(memTy); + assert(ty && ty.isa()); + auto seqTy = ty.cast(); + auto one = builder.template create(loc, 1); + const auto dimension = seqTy.getDimension(); + if (shapeVal) { + assert(dimension == mlir::cast(shapeVal.getDefiningOp()) + .getType() + .getRank()); + } + for (auto i : llvm::enumerate(indices)) { + if (i.index() < dimension) { + assert(fir::isa_integer(i.value().getType())); + result.push_back( + builder.template create(loc, i.value(), one)); + } else { + result.push_back(i.value()); + } + } + return result; + } + const auto dimension = origins.size(); + unsigned origOff = 0; + for (auto i : llvm::enumerate(indices)) { + if (i.index() < dimension) + result.push_back(builder.template create( + loc, i.value(), origins[origOff++])); + else + result.push_back(i.value()); + } + return result; +} + +template +llvm::SmallVector createLoopNest( + mlir::Location loc, B &builder, llvm::iterator_range lows, + llvm::iterator_range highs, llvm::iterator_range steps, + llvm::ArrayRef threadedVals, bool unordered = false) { + llvm::SmallVector loops; + llvm::SmallVector inners(threadedVals.begin(), + threadedVals.end()); + for (auto iter0 = lows.begin(), iter1 = highs.begin(), iter2 = steps.begin(); + iter1 != highs.end(); ++iter0, ++iter1, ++iter2) { + auto lp = builder.template create( + loc, *iter0, *iter1, *iter2, unordered, + /*finalCount=*/false, inners); + loops.push_back(lp); + inners.assign(lp.getRegionIterArgs().begin(), lp.getRegionIterArgs().end()); + builder.setInsertionPointToStart(lp.getBody()); + } + auto numLoops = loops.size(); + for (decltype(numLoops) i = 0; i + 1 < numLoops; ++i) { + 
builder.setInsertionPointToEnd(loops[i].getBody()); + builder.template create(loc, loops[i + 1].getResults()); + } + builder.setInsertionPointAfter(loops[0]); + llvm::errs() << loops[0] << '\n'; + return loops; +} + +template +llvm::SmallVector createLoopNest( + mlir::Location loc, B &builder, llvm::ArrayRef lows, + llvm::ArrayRef highs, llvm::ArrayRef steps, + llvm::ArrayRef threadedVals, bool unordered = false) { + return createLoopNest( + loc, builder, llvm::make_range(lows.begin(), lows.end()), + llvm::make_range(highs.begin(), highs.end()), + llvm::make_range(steps.begin(), steps.end()), threadedVals, unordered); +} + +} // namespace fir::factory + +#endif // FORTRAN_OPTIMIZER_TRANSFORMS_FACTORY_H diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr createAbstractResultOptPass(); std::unique_ptr createAffineDemotionPass(); +std::unique_ptr createArrayValueCopyPass(); std::unique_ptr createFirToCfgPass(); std::unique_ptr createCharacterConversionPass(); std::unique_ptr createExternalNameConversionPass(); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -74,6 +74,28 @@ ]; } +def ArrayValueCopy : FunctionPass<"array-value-copy"> { + let summary = "Convert array value operations to memory operations."; + let description = [{ + Transform the set of array value primitives to a memory-based array + representation. + + The Ops `array_load`, `array_store`, `array_fetch`, and `array_update` are + used to manage abstract aggregate array values. A simple analysis is done + to determine if there are potential dependences between these operations. 
+ If not, these array operations can be lowered to work directly on the memory + representation. If there is a potential conflict, a temporary is created + along with appropriate copy-in/copy-out operations. Here, a more refined + analysis might be deployed, such as using the affine framework. + + This pass is required before code gen to the LLVM IR dialect. + }]; + let constructor = "::fir::createArrayValueCopyPass()"; + let dependentDialects = [ + "fir::FIROpsDialect", "mlir::StandardOpsDialect" + ]; +} + def CharacterConversion : Pass<"character-conversion"> { let summary = "Convert CHARACTER entities with different KINDs"; let description = [{ diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -6,6 +6,7 @@ DoLoopHelper.cpp FIRBuilder.cpp MutableBox.cpp + Runtime/Assign.cpp DEPENDS FIRDialect diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -7,15 +7,20 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Lower/Todo.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/MutableBox.h" +#include "flang/Optimizer/Builder/Runtime/Assign.h" +#include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Support/FatalError.h" #include "flang/Optimizer/Support/InternalNames.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MD5.h" static llvm::cl::opt @@ -110,7 +115,7 @@ const 
llvm::APFloat &value) { if (fltTy.isa()) { auto attr = getFloatAttr(fltTy, value); - return create(loc, fltTy, attr); + return create(loc, fltTy, attr); } llvm_unreachable("should use builtin floating-point type"); } @@ -257,6 +262,34 @@ return glob; } +mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc, + mlir::Type toTy, + mlir::Value val) { + assert(toTy && "store location must be typed"); + auto fromTy = val.getType(); + if (fromTy == toTy) + return val; + // fir::factory::ComplexExprHelper helper{*this, loc}; + // if ((fir::isa_real(fromTy) || fir::isa_integer(fromTy)) && + // fir::isa_complex(toTy)) { + // // imaginary part is zero + // auto eleTy = helper.getComplexPartType(toTy); + // auto cast = createConvert(loc, eleTy, val); + // llvm::APFloat zero{ + // kindMap.getFloatSemantics(toTy.cast().getFKind()), + // 0}; + // auto imag = createRealConstant(loc, eleTy, zero); + // return helper.createComplex(toTy, cast, imag); + // } + // if (fir::isa_complex(fromTy) && + // (fir::isa_integer(toTy) || fir::isa_real(toTy))) { + // // drop the imaginary part + // auto rp = helper.extractComplexPart(val, /*isImagPart=*/false); + // return createConvert(loc, toTy, rp); + // } + return createConvert(loc, toTy, val); +} + mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc, mlir::Type toTy, mlir::Value val) { if (val.getType() != toTy) { @@ -327,23 +360,135 @@ [&](auto) -> mlir::Value { fir::emitFatalError(loc, "not an array"); }); } -static mlir::Value genNullPointerComparison(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value addr, - arith::CmpIPredicate condition) { +mlir::Value fir::FirOpBuilder::createSlice(mlir::Location loc, + const fir::ExtendedValue &exv, + mlir::ValueRange triples, + mlir::ValueRange path) { + if (triples.empty()) { + // If there is no slicing by triple notation, then take the whole array. 
+ auto fullShape = [&](const llvm::ArrayRef lbounds, + llvm::ArrayRef extents) -> mlir::Value { + llvm::SmallVector trips; + auto idxTy = getIndexType(); + auto one = createIntegerConstant(loc, idxTy, 1); + auto sliceTy = fir::SliceType::get(getContext(), extents.size()); + if (lbounds.empty()) { + for (auto v : extents) { + trips.push_back(one); + trips.push_back(v); + trips.push_back(one); + } + return create(loc, sliceTy, trips, path); + } + for (auto [lbnd, ext] : llvm::zip(lbounds, extents)) { + auto lb = createConvert(loc, idxTy, lbnd); + trips.push_back(lb); + trips.push_back(ext); + trips.push_back(one); + } + return create(loc, sliceTy, trips, path); + }; + return exv.match( + [&](const fir::ArrayBoxValue &box) { + return fullShape(box.getLBounds(), box.getExtents()); + }, + [&](const fir::CharArrayBoxValue &box) { + return fullShape(box.getLBounds(), box.getExtents()); + }, + [&](const fir::BoxValue &box) { + auto extents = fir::factory::readExtents(*this, loc, box); + return fullShape(box.getLBounds(), extents); + }, + [&](const fir::MutableBoxValue &) -> mlir::Value { + // MutableBoxValue must be read into another category to work with + // them outside of allocation/assignment contexts. 
+ fir::emitFatalError(loc, "createSlice on MutableBoxValue"); + }, + [&](auto) -> mlir::Value { fir::emitFatalError(loc, "not an array"); }); + } + auto rank = exv.rank(); + auto sliceTy = fir::SliceType::get(getContext(), rank); + return create(loc, sliceTy, triples, path); +} + +mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc, + const fir::ExtendedValue &exv) { + auto itemAddr = fir::getBase(exv); + if (itemAddr.getType().isa()) + return itemAddr; + auto elementType = fir::dyn_cast_ptrEleTy(itemAddr.getType()); + if (!elementType) + mlir::emitError(loc, "internal: expected a memory reference type ") + << itemAddr.getType(); + auto boxTy = fir::BoxType::get(elementType); + return exv.match( + [&](const fir::ArrayBoxValue &box) -> mlir::Value { + auto s = createShape(loc, exv); + return create(loc, boxTy, itemAddr, s); + }, + [&](const fir::CharArrayBoxValue &box) -> mlir::Value { + auto s = createShape(loc, exv); + if (fir::factory::CharacterExprHelper::hasConstantLengthInType(exv)) + return create(loc, boxTy, itemAddr, s); + + mlir::Value emptySlice; + llvm::SmallVector lenParams{box.getLen()}; + return create(loc, boxTy, itemAddr, s, emptySlice, + lenParams); + }, + [&](const fir::CharBoxValue &box) -> mlir::Value { + if (fir::factory::CharacterExprHelper::hasConstantLengthInType(exv)) + return create(loc, boxTy, itemAddr); + mlir::Value emptyShape, emptySlice; + llvm::SmallVector lenParams{box.getLen()}; + return create(loc, boxTy, itemAddr, emptyShape, + emptySlice, lenParams); + }, + [&](const fir::MutableBoxValue &x) -> mlir::Value { + return create( + loc, fir::factory::getMutableIRBox(*this, loc, x)); + }, + [&](const auto &) -> mlir::Value { + return create(loc, boxTy, itemAddr); + }); +} + +static mlir::Value +genNullPointerComparison(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value addr, + mlir::arith::CmpIPredicate condition) { auto intPtrTy = builder.getIntPtrType(); auto ptrToInt = builder.createConvert(loc, intPtrTy, addr); 
auto c0 = builder.createIntegerConstant(loc, intPtrTy, 0); - return builder.create(loc, condition, ptrToInt, c0); + return builder.create(loc, condition, ptrToInt, c0); } mlir::Value fir::FirOpBuilder::genIsNotNull(mlir::Location loc, mlir::Value addr) { - return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::ne); + return genNullPointerComparison(*this, loc, addr, + mlir::arith::CmpIPredicate::ne); } mlir::Value fir::FirOpBuilder::genIsNull(mlir::Location loc, mlir::Value addr) { - return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::eq); + return genNullPointerComparison(*this, loc, addr, + mlir::arith::CmpIPredicate::eq); +} + +mlir::Value fir::FirOpBuilder::genExtentFromTriplet(mlir::Location loc, + mlir::Value lb, + mlir::Value ub, + mlir::Value step, + mlir::Type type) { + auto zero = createIntegerConstant(loc, type, 0); + lb = createConvert(loc, type, lb); + ub = createConvert(loc, type, ub); + step = createConvert(loc, type, step); + auto diff = create(loc, ub, lb); + auto add = create(loc, diff, step); + auto div = create(loc, add, step); + auto cmp = create(loc, mlir::arith::CmpIPredicate::sgt, + div, zero); + return create(loc, cmp, div, zero); } //===--------------------------------------------------------------------===// @@ -376,6 +521,67 @@ }); } +mlir::Value fir::factory::readExtent(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &box, + unsigned dim) { + assert(box.rank() > dim); + return box.match( + [&](const fir::ArrayBoxValue &x) -> mlir::Value { + return x.getExtents()[dim]; + }, + [&](const fir::CharArrayBoxValue &x) -> mlir::Value { + return x.getExtents()[dim]; + }, + [&](const fir::BoxValue &x) -> mlir::Value { + if (!x.getExplicitExtents().empty()) + return x.getExplicitExtents()[dim]; + auto idxTy = builder.getIndexType(); + auto dimVal = builder.createIntegerConstant(loc, idxTy, dim); + return builder + .create(loc, idxTy, idxTy, idxTy, x.getAddr(), + dimVal) + 
.getResult(1); + }, + [&](const fir::MutableBoxValue &x) -> mlir::Value { + // MutableBoxValue must be read into another category to work with them + // outside of allocation/assignment contexts. + fir::emitFatalError(loc, "readExtents on MutableBoxValue"); + }, + [&](const auto &) -> mlir::Value { + fir::emitFatalError(loc, "extent inquiry on scalar"); + }); +} + +mlir::Value fir::factory::readLowerBound(fir::FirOpBuilder &, + mlir::Location loc, + const fir::ExtendedValue &box, + unsigned dim, + mlir::Value defaultValue) { + assert(box.rank() > dim); + auto lb = box.match( + [&](const fir::ArrayBoxValue &x) -> mlir::Value { + return x.getLBounds().empty() ? mlir::Value{} : x.getLBounds()[dim]; + }, + [&](const fir::CharArrayBoxValue &x) -> mlir::Value { + return x.getLBounds().empty() ? mlir::Value{} : x.getLBounds()[dim]; + }, + [&](const fir::BoxValue &x) -> mlir::Value { + return x.getLBounds().empty() ? mlir::Value{} : x.getLBounds()[dim]; + }, + [&](const fir::MutableBoxValue &) -> mlir::Value { + // MutableBoxValue must be read into another category to work with them + // outside of allocation/assignment contexts. 
+ fir::emitFatalError(loc, "readLowerBound on MutableBoxValue"); + }, + [&](const auto &) -> mlir::Value { + fir::emitFatalError(loc, "lower bound inquiry on scalar"); + }); + if (lb) + return lb; + return defaultValue; +} + llvm::SmallVector fir::factory::readExtents(fir::FirOpBuilder &builder, mlir::Location loc, const fir::BoxValue &box) { @@ -416,6 +622,29 @@ [&](const auto &) -> llvm::SmallVector { return {}; }); } +fir::ExtendedValue fir::factory::readBoxValue(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::BoxValue &box) { + assert(!box.isUnlimitedPolymorphic() && !box.hasAssumedRank() && + "cannot read unlimited polymorphic or assumed rank fir.box"); + auto addr = + builder.create(loc, box.getMemTy(), box.getAddr()); + if (box.isCharacter()) { + auto len = fir::factory::readCharLen(builder, loc, box); + if (box.rank() == 0) + return fir::CharBoxValue(addr, len); + return fir::CharArrayBoxValue(addr, len, + fir::factory::readExtents(builder, loc, box), + box.getLBounds()); + } + if (box.isDerivedWithLengthParameters()) + TODO(loc, "read fir.box with length parameters"); + if (box.rank() == 0) + return addr; + return fir::ArrayBoxValue(addr, fir::factory::readExtents(builder, loc, box), + box.getLBounds()); +} + std::string fir::factory::uniqueCGIdent(llvm::StringRef prefix, llvm::StringRef name) { // For "long" identifiers use a hash value @@ -474,3 +703,186 @@ loc, builder.getCharacterLengthType(), str.size()); return fir::CharBoxValue{addr, len}; } + +llvm::SmallVector +fir::factory::createExtents(fir::FirOpBuilder &builder, mlir::Location loc, + fir::SequenceType seqTy) { + llvm::SmallVector extents; + auto idxTy = builder.getIndexType(); + for (auto ext : seqTy.getShape()) + extents.emplace_back( + ext == fir::SequenceType::getUnknownExtent() + ? builder.create(loc, idxTy).getResult() + : builder.createIntegerConstant(loc, idxTy, ext)); + return extents; +} + +// FIXME: This needs some work. 
To correctly determine the extended value of a +// component, one needs the base object, its type, and its type parameters. (An +// alternative would be to provide an already computed address of the final +// component rather than the base object's address, the point being the result +// will require the address of the final component to create the extended +// value.) One further needs the full path of components being applied. One +// needs to apply type-based expressions to type parameters along this said +// path. (See applyPathToType for a type-only derivation.) Finally, one needs to +// compose the extended value of the terminal component, including all of its +// parameters: array lower bounds expressions, extents, type parameters, etc. +// Any of these properties may be deferred until runtime in Fortran. This +// operation may therefore generate a sizeable block of IR, including calls to +// type-based helper functions, so caching the result of this operation in the +// client would be advised as well. +fir::ExtendedValue fir::factory::componentToExtendedValue( + fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value component) { + auto fieldTy = component.getType(); + if (auto ty = fir::dyn_cast_ptrEleTy(fieldTy)) + fieldTy = ty; + if (fieldTy.isa()) { + llvm::SmallVector nonDeferredTypeParams; + auto eleTy = fir::unwrapSequenceType(fir::dyn_cast_ptrOrBoxEleTy(fieldTy)); + if (auto charTy = eleTy.dyn_cast()) { + auto lenTy = builder.getCharacterLengthType(); + if (charTy.hasConstantLen()) + nonDeferredTypeParams.emplace_back( + builder.createIntegerConstant(loc, lenTy, charTy.getLen())); + // TODO: Starting with F2003, the dynamic character length might depend + // on a PDT length parameter. There is no way to distinguish that from a + // deferred length here yet.
+ } + if (auto recTy = eleTy.dyn_cast()) + if (recTy.getNumLenParams() > 0) + TODO(loc, "allocatable and pointer components non deferred length " + "parameters"); + + return fir::MutableBoxValue(component, nonDeferredTypeParams, + /*mutableProperties=*/{}); + } + llvm::SmallVector extents; + if (auto seqTy = fieldTy.dyn_cast()) { + fieldTy = seqTy.getEleTy(); + auto idxTy = builder.getIndexType(); + for (auto extent : seqTy.getShape()) { + if (extent == fir::SequenceType::getUnknownExtent()) + TODO(loc, "array component shape depending on length parameters"); + extents.emplace_back(builder.createIntegerConstant(loc, idxTy, extent)); + } + } + if (auto charTy = fieldTy.dyn_cast()) { + auto cstLen = charTy.getLen(); + if (cstLen == fir::CharacterType::unknownLen()) + TODO(loc, "get character component length from length type parameters"); + auto len = builder.createIntegerConstant( + loc, builder.getCharacterLengthType(), cstLen); + if (!extents.empty()) + return fir::CharArrayBoxValue{component, len, extents}; + return fir::CharBoxValue{component, len}; + } + if (auto recordTy = fieldTy.dyn_cast()) + if (recordTy.getNumLenParams() != 0) + TODO(loc, + "lower component ref that is a derived type with length parameter"); + if (!extents.empty()) + return fir::ArrayBoxValue{component, extents}; + return component; +} + +fir::ExtendedValue fir::factory::arrayElementToExtendedValue( + fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &array, mlir::Value element) { + return array.match( + [&](const fir::CharBoxValue &cb) -> fir::ExtendedValue { + return cb.clone(element); + }, + [&](const fir::CharArrayBoxValue &bv) -> fir::ExtendedValue { + return bv.cloneElement(element); + }, + [&](const fir::BoxValue &box) -> fir::ExtendedValue { + if (box.isCharacter()) { + auto len = fir::factory::readCharLen(builder, loc, box); + return fir::CharBoxValue{element, len}; + } + if (box.isDerivedWithLengthParameters()) + TODO(loc, "get length parameters from 
derived type BoxValue"); + return element; + }, + [&](const auto &) -> fir::ExtendedValue { return element; }); +} + +fir::ExtendedValue fir::factory::arraySectionElementToExtendedValue( + fir::FirOpBuilder &builder, mlir::Location loc, + const fir::ExtendedValue &array, mlir::Value element, mlir::Value slice) { + if (!slice) + return arrayElementToExtendedValue(builder, loc, array, element); + auto sliceOp = mlir::dyn_cast_or_null(slice.getDefiningOp()); + assert(sliceOp && "slice must be a sliceOp"); + if (sliceOp.fields().empty()) + return arrayElementToExtendedValue(builder, loc, array, element); + // For F95, using componentToExtendedValue will work, but when PDTs are + // lowered, it will be required to go down the slice to propagate the length + // parameters. + return fir::factory::componentToExtendedValue(builder, loc, element); +} + +/// Can the assignment of this record type be implemented with a simple memory +/// copy? +static bool recordTypeCanBeMemCopied(fir::RecordType recordType) { + if (fir::hasDynamicSize(recordType)) + return false; + for (auto [_, fieldType] : recordType.getTypeList()) { + // Derived type component may have user assignment (so far, we cannot tell + // in FIR, so assume it is always the case, TODO: get the actual info). + if (fir::unwrapSequenceType(fieldType).isa()) + return false; + // Allocatable components need deep copy. + if (auto boxType = fieldType.dyn_cast()) + if (boxType.getEleTy().isa()) + return false; + } + // Constant size components without user defined assignment and pointers can + // be memcopied.
+ return true; +} + +void fir::factory::genRecordAssignment(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &lhs, + const fir::ExtendedValue &rhs) { + assert(lhs.rank() == 0 && rhs.rank() == 0 && "assume scalar assignment"); + auto baseTy = fir::dyn_cast_ptrOrBoxEleTy(fir::getBase(lhs).getType()); + assert(baseTy && "must be a memory type"); + // Box operands may be polymorphic, it is not entirely clear from 10.2.1.3 + // if the assignment is performed on the dynamic or declared type. Use the + // runtime assuming it is performed on the dynamic type. + bool hasBoxOperands = fir::getBase(lhs).getType().isa() || + fir::getBase(rhs).getType().isa(); + auto recTy = baseTy.dyn_cast(); + assert(recTy && "must be a record type"); + if (hasBoxOperands || !recordTypeCanBeMemCopied(recTy)) { + auto to = fir::getBase(builder.createBox(loc, lhs)); + auto from = fir::getBase(builder.createBox(loc, rhs)); + // The runtime entry point may modify the LHS descriptor if it is + // an allocatable. Allocatable assignment is handled elsewhere in lowering, + // so just create a fir.ref<fir.box<>> from the fir.box to comply with the + // runtime interface, but assume the fir.box is unchanged. + // TODO: does this hold true with polymorphic entities? + auto toMutableBox = builder.createTemporary(loc, to.getType()); + builder.create(loc, to, toMutableBox); + fir::runtime::genAssign(builder, loc, toMutableBox, from); + return; + } + // Otherwise, the derived type has a compile-time constant size, so the + // component-by-component assignment can be replaced by a memory copy.
+ auto rhsVal = fir::getBase(rhs); + if (fir::isa_ref_type(rhsVal.getType())) + rhsVal = builder.create(loc, rhsVal); + builder.create(loc, rhsVal, fir::getBase(lhs)); +} + +mlir::TupleType +fir::factory::getRaggedArrayHeaderType(fir::FirOpBuilder &builder) { + auto i64Ty = builder.getIntegerType(64); + auto arrTy = fir::SequenceType::get(builder.getIntegerType(8), 1); + auto buffTy = fir::HeapType::get(arrTy); + auto extTy = fir::SequenceType::get(i64Ty, 1); + auto shTy = fir::HeapType::get(extTy); + return mlir::TupleType::get(builder.getContext(), {i64Ty, buffTy, shTy}); +} diff --git a/flang/lib/Optimizer/Builder/Runtime/Assign.cpp b/flang/lib/Optimizer/Builder/Runtime/Assign.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Builder/Runtime/Assign.cpp @@ -0,0 +1,26 @@ +//===-- Assign.cpp -- generate assignment runtime API calls ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/Runtime/Assign.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" +#include "flang/Runtime/assign.h" + +using namespace Fortran::runtime; + +void fir::runtime::genAssign(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value destBox, mlir::Value sourceBox) { + auto func = fir::runtime::getRuntimeFunc(loc, builder); + auto fTy = func.getType(); + auto sourceFile = fir::factory::locationToFilename(builder, loc); + auto sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); + auto args = fir::runtime::createArguments(builder, loc, fTy, destBox, + sourceBox, sourceFile, sourceLine); + builder.create(loc, func, args); +} diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1144,6 +1144,10 @@ // GlobalOp //===----------------------------------------------------------------------===// +mlir::Type fir::GlobalOp::resultType() { + return wrapAllocaResultType(getType()); +} + static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { // Parse the optional linkage llvm::StringRef linkage; @@ -1311,10 +1315,6 @@ // GlobalLenOp //===----------------------------------------------------------------------===// -mlir::Type fir::GlobalOp::resultType() { - return wrapAllocaResultType(getType()); -} - static mlir::ParseResult parseGlobalLenOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::StringRef fieldName; diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -200,15 +200,15 @@ mlir::Type 
dyn_cast_ptrEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case( - [](auto p) { return p.getEleTy(); }) + .Case([](auto p) { return p.getEleTy(); }) .Default([](mlir::Type) { return mlir::Type{}; }); } mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) { return llvm::TypeSwitch(t) - .Case( - [](auto p) { return p.getEleTy(); }) + .Case([](auto p) { return p.getEleTy(); }) .Case([](auto p) { auto eleTy = p.getEleTy(); if (auto ty = fir::dyn_cast_ptrEleTy(eleTy)) @@ -471,6 +471,19 @@ printer << getMnemonic() << "<" << getFKind() << '>'; } +//===----------------------------------------------------------------------===// +// LLVMPointerType +//===----------------------------------------------------------------------===// + +// `llvm_ptr` `<` type `>` +mlir::Type fir::LLVMPointerType::parse(mlir::DialectAsmParser &parser) { + return parseTypeSingleton(parser); +} + +void fir::LLVMPointerType::print(mlir::DialectAsmPrinter &printer) const { + printer << getMnemonic() << "<" << getEleTy() << '>'; +} + //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// @@ -865,7 +878,7 @@ void FIROpsDialect::registerTypes() { addTypes(); + LLVMPointerType, PointerType, RealType, RecordType, ReferenceType, + SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType, + TypeDescType, fir::VectorType>(); } diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -0,0 +1,891 @@ +//===-- ArrayValueCopy.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "flang/Optimizer/Transforms/Factory.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "flang-array-value-copy" + +using namespace fir; + +using OperationUseMapT = llvm::DenseMap; + +namespace { + +/// Array copy analysis. +/// Perform an interference analysis between array values. +/// +/// Lowering will generate a sequence of the following form. +/// ```mlir +/// %a_1 = fir.array_load %array_1(%shape) : ... +/// ... +/// %a_j = fir.array_load %array_j(%shape) : ... +/// ... +/// %a_n = fir.array_load %array_n(%shape) : ... +/// ... +/// %v_i = fir.array_fetch %a_i, ... +/// %a_j1 = fir.array_update %a_j, ... +/// ... +/// fir.array_merge_store %a_j, %a_jn to %array_j : ... +/// ``` +/// +/// The analysis is to determine if there are any conflicts. A conflict is when +/// one of the following cases occurs. +/// +/// 1. There is an `array_update` to an array value, a_j, such that a_j was +/// loaded from the same array memory reference (array_j) but with a different +/// shape than the other array values a_i, where i != j. [Possible overlapping +/// arrays.] +/// +/// 2. There is either an array_fetch or array_update of a_j with a different +/// set of index values. [Possible loop-carried dependence.] +/// +/// If none of the array values overlap in storage and the accesses are not +/// loop-carried, then the arrays are conflict-free and no copies are required.
+class ArrayCopyAnalysis { +public: + using ConflictSetT = llvm::SmallPtrSet; + using UseSetT = llvm::SmallPtrSet; + using LoadMapSetsT = llvm::DenseMap; + + ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { + construct(op->getRegions()); + } + + mlir::Operation *getOperation() const { return operation; } + + /// Return true iff \p op (an array load or `array_merge_store`) is in the set of potential conflicts. + bool hasPotentialConflict(mlir::Operation *op) const { + LLVM_DEBUG(llvm::dbgs() + << "looking for a conflict on " << *op + << " and the set has a total of " << conflicts.size() << '\n'); + return conflicts.contains(op); + } + + /// Return the use map. The use map maps array fetch and update operations + /// back to the array load that is the original source of the array value. + const OperationUseMapT &getUseMap() const { return useMap; } + + /// For ArrayLoad `load`, return the transitive set of all OpOperands. + UseSetT getLoadUseSet(mlir::Operation *load) const { + assert(loadMapSets.count(load) && "analysis missed an array load?"); + return loadMapSets.lookup(load); + } + + /// Get all the array value operations that use the original array value + /// loaded by `load`. The result is written to `accesses`. + void arrayAccesses(llvm::SmallVectorImpl &accesses, + ArrayLoadOp load); + +private: + void construct(mlir::MutableArrayRef regions); + + mlir::Operation *operation; // operation that analysis ran upon + ConflictSetT conflicts; // set of conflicts (loads and merge stores) + OperationUseMapT useMap; // fetch/update op -> array load it originates from + LoadMapSetsT loadMapSets; // array load -> transitive set of its OpOperand uses +}; +} // namespace + +namespace { +/// Helper class to collect all array operations that produced an array value.
+class ReachCollector { +public: + ReachCollector(llvm::SmallVectorImpl &reach, + mlir::Region *loopRegion) + : reach{reach}, loopRegion{loopRegion} {} + + void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { + if (range.empty()) { + collectArrayAccessFrom(op, mlir::Value{}); + return; + } + for (auto v : range) + collectArrayAccessFrom(v); + } + + void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { + // `val` is defined by an Op, process the defining Op. + // If `val` is defined by a region containing Op, we want to drill down + // and through that Op's region(s). + LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); + auto popFn = [&](auto rop) { + assert(val && "op must have a result value"); + auto resNum = val.cast().getResultNumber(); + llvm::SmallVector results; + rop.resultToSourceOps(results, resNum); + for (auto u : results) + collectArrayAccessFrom(u); + }; + if (auto rop = mlir::dyn_cast(op)) { + popFn(rop); + return; + } + if (auto rop = mlir::dyn_cast(op)) { + popFn(rop); + return; + } + if (auto rop = mlir::dyn_cast(op)) { + popFn(rop); + return; + } + if (auto box = mlir::dyn_cast(op)) { + for (auto *user : box.memref().getUsers()) + if (user != op) + collectArrayAccessFrom(user, user->getResults()); + return; + } + if (auto mergeStore = mlir::dyn_cast(op)) { + if (opIsInsideLoops(mergeStore)) + collectArrayAccessFrom(mergeStore.sequence()); + return; + } + + if (mlir::isa(op)) { + // Look for any stores inside the loops, and collect an array operation + // that produced the value being stored to it. + for (auto *user : op->getUsers()) + if (auto store = mlir::dyn_cast(user)) + if (opIsInsideLoops(store)) + collectArrayAccessFrom(store.value()); + return; + } + + // Otherwise, Op does not contain a region so just chase its operands. 
+ if (mlir::isa( + op)) { + LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); + reach.emplace_back(op); + } + // Array modify assignment is performed on the result. So the analysis + // must look at what is done with the result. + if (mlir::isa(op)) + for (auto *user : op->getResult(0).getUsers()) + followUsers(user); + + for (auto u : op->getOperands()) + collectArrayAccessFrom(u); + } + + void collectArrayAccessFrom(mlir::BlockArgument ba) { + auto *parent = ba.getOwner()->getParentOp(); + // If inside an Op holding a region, the block argument corresponds to an + // argument passed to the containing Op. + auto popFn = [&](auto rop) { + collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); + }; + if (auto rop = mlir::dyn_cast(parent)) { + popFn(rop); + return; + } + if (auto rop = mlir::dyn_cast(parent)) { + popFn(rop); + return; + } + // Otherwise, a block argument is provided via the pred blocks. + for (auto *pred : ba.getOwner()->getPredecessors()) { + auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); + collectArrayAccessFrom(u); + } + } + + // Recursively trace operands to find all array operations relating to the + // values merged. + void collectArrayAccessFrom(mlir::Value val) { + if (!val || visited.contains(val)) + return; + visited.insert(val); + + // Process a block argument. + if (auto ba = val.dyn_cast()) { + collectArrayAccessFrom(ba); + return; + } + + // Process an Op. + if (auto *op = val.getDefiningOp()) { + collectArrayAccessFrom(op, val); + return; + } + + fir::emitFatalError(val.getLoc(), "unhandled value"); + } + + /// Return all ops that produce the array value that is stored into the + /// `array_merge_store`.
+ static void reachingValues(llvm::SmallVectorImpl &reach, + mlir::Value seq) { + reach.clear(); + mlir::Region *loopRegion = nullptr; + if (auto doLoop = + mlir::dyn_cast_or_null(seq.getDefiningOp())) + loopRegion = &doLoop->getRegion(0); + ReachCollector collector(reach, loopRegion); + collector.collectArrayAccessFrom(seq); + } + +private: + /// Is \p op inside the loop nest region? + /// FIXME: replace this structural dependence with graph properties. + bool opIsInsideLoops(mlir::Operation *op) const { + auto *region = op->getParentRegion(); + while (region) { + if (region == loopRegion) + return true; + region = region->getParentRegion(); + } + return false; + } + + /// Recursively trace the uses of an operation's results, calling + /// collectArrayAccessFrom on the direct and indirect user operands. + void followUsers(mlir::Operation *op) { + for (auto userOperand : op->getOperands()) + collectArrayAccessFrom(userOperand); + // Go through potential converts/coordinate_op. + for (auto indirectUser : op->getUsers()) + followUsers(indirectUser); + } + + llvm::SmallVectorImpl &reach; + llvm::SmallPtrSet visited; + /// Region of the loop nest that produced the array value. + mlir::Region *loopRegion; +}; +} // namespace + +/// Find all the array operations that access the array value that is loaded by +/// the array load operation, `load`. +void ArrayCopyAnalysis::arrayAccesses( + llvm::SmallVectorImpl &accesses, ArrayLoadOp load) { + accesses.clear(); + auto lmIter = loadMapSets.find(load); + if (lmIter != loadMapSets.end()) { + for (auto *opnd : lmIter->second) { + auto *owner = opnd->getOwner(); + if (mlir::isa(owner)) + accesses.push_back(owner); + } + return; + } + + UseSetT visited; + llvm::SmallVector queue; // uses of ArrayLoad[orig] + + auto appendToQueue = [&](mlir::Value val) { + for (auto &use : val.getUses()) + if (!visited.count(&use)) { + visited.insert(&use); + queue.push_back(&use); + } + }; + + // Build the set of uses of `original`. 
+ // let USES = { uses of original fir.load } + appendToQueue(load); + + // Process the worklist until done. + while (!queue.empty()) { + auto *operand = queue.pop_back_val(); + auto *owner = operand->getOwner(); + if (!owner) + continue; + auto structuredLoop = [&](auto ro) { + if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { + auto arg = blockArg.getArgNumber(); + auto output = ro.getResult(ro.finalValue() ? arg : arg - 1); + appendToQueue(output); + appendToQueue(blockArg); + } + }; + auto branchOp = [&](mlir::Block *dest, auto operands) { + for (auto i : llvm::enumerate(operands)) + if (operand->get() == i.value()) { + auto blockArg = dest->getArgument(i.index()); + appendToQueue(blockArg); + } + }; + // Thread uses into structured loop bodies and return value uses. + if (auto ro = mlir::dyn_cast(owner)) { + structuredLoop(ro); + } else if (auto ro = mlir::dyn_cast(owner)) { + structuredLoop(ro); + } else if (auto rs = mlir::dyn_cast(owner)) { + // Thread any uses of fir.if that return the marked array value. + auto *parent = rs->getParentRegion()->getParentOp(); + if (auto ifOp = mlir::dyn_cast(parent)) + appendToQueue(ifOp.getResult(operand->getOperandNumber())); + } else if (mlir::isa(owner)) { + // Keep track of array value fetches. + LLVM_DEBUG(llvm::dbgs() + << "add fetch {" << *owner << "} to array value set\n"); + accesses.push_back(owner); + } else if (auto update = mlir::dyn_cast(owner)) { + // Keep track of array value updates and thread the return value uses. + LLVM_DEBUG(llvm::dbgs() + << "add update {" << *owner << "} to array value set\n"); + accesses.push_back(owner); + appendToQueue(update.getResult()); + } else if (auto update = mlir::dyn_cast(owner)) { + // Keep track of array value modification and thread the return value + // uses. 
+ LLVM_DEBUG(llvm::dbgs() + << "add modify {" << *owner << "} to array value set\n"); + accesses.push_back(owner); + appendToQueue(update.getResult(1)); + } else if (auto br = mlir::dyn_cast(owner)) { + branchOp(br.getDest(), br.destOperands()); + } else if (auto br = mlir::dyn_cast(owner)) { + branchOp(br.getTrueDest(), br.getTrueOperands()); + branchOp(br.getFalseDest(), br.getFalseOperands()); + } else if (mlir::isa(owner)) { + // do nothing + } else { + llvm::report_fatal_error("array value reached unexpected op"); + } + } + loadMapSets.insert({load, visited}); +} + +/// Is there a conflict between the array value that was updated and to be +/// stored to `st` and the set of arrays loaded (`reach`) and used to compute +/// the updated value? +static bool conflictOnLoad(llvm::ArrayRef reach, + ArrayMergeStoreOp st) { + mlir::Value load; + auto addr = st.memref(); + auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); + for (auto *op : reach) + if (auto ld = mlir::dyn_cast(op)) { + auto ldTy = ld.memref().getType(); + if (auto boxTy = ldTy.dyn_cast()) + ldTy = boxTy.getEleTy(); + if (ldTy.isa() && stEleTy == dyn_cast_ptrEleTy(ldTy)) + return true; + if (ld.memref() == addr) { + if (ld.getResult() != st.original()) + return true; + if (load) + return true; + load = ld; + } + } + return false; +} + +static bool conflictOnMerge(llvm::ArrayRef accesses) { + if (accesses.size() < 2) + return false; + llvm::SmallVector indices; + LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size() + << " accesses on the list\n"); + for (auto *op : accesses) { + llvm::SmallVector compareVector; + if (auto u = mlir::dyn_cast(op)) { + if (indices.empty()) { + indices = u.indices(); + continue; + } + compareVector = u.indices(); + } else if (auto f = mlir::dyn_cast(op)) { + if (indices.empty()) { + indices = f.indices(); + continue; + } + compareVector = f.indices(); + } else if (auto f = mlir::dyn_cast(op)) { + if (indices.empty()) { + indices = 
f.indices(); + continue; + } + compareVector = f.indices(); + } else { + mlir::emitError(op->getLoc(), "unexpected operation in analysis"); + } + if (compareVector.size() != indices.size() || + llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) { + return std::get<0>(pair) != std::get<1>(pair); + })) + return true; + LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); + } + return false; +} + +// Is either type of conflict present? +inline bool conflictDetected(llvm::ArrayRef reach, + llvm::ArrayRef accesses, + ArrayMergeStoreOp st) { + return conflictOnLoad(reach, st) || conflictOnMerge(accesses); +} + +/// Constructor of the array copy analysis. +/// This performs the analysis and saves the intermediate results. +void ArrayCopyAnalysis::construct(mlir::MutableArrayRef regions) { + for (auto &region : regions) + for (auto &block : region.getBlocks()) + for (auto &op : block.getOperations()) { + if (op.getNumRegions()) + construct(op.getRegions()); + if (auto st = mlir::dyn_cast(op)) { + llvm::SmallVector values; + ReachCollector::reachingValues(values, st.sequence()); + llvm::SmallVector accesses; + arrayAccesses(accesses, + mlir::cast(st.original().getDefiningOp())); + if (conflictDetected(values, accesses, st)) { + LLVM_DEBUG(llvm::dbgs() + << "CONFLICT: copies required for " << st << '\n' + << " adding conflicts on: " << op << " and " + << st.original() << '\n'); + conflicts.insert(&op); + conflicts.insert(st.original().getDefiningOp()); + } + auto *ld = st.original().getDefiningOp(); + LLVM_DEBUG(llvm::dbgs() + << "map: adding {" << *ld << " -> " << st << "}\n"); + useMap.insert({ld, &op}); + } else if (auto load = mlir::dyn_cast(op)) { + llvm::SmallVector accesses; + arrayAccesses(accesses, load); + LLVM_DEBUG(llvm::dbgs() << "process load: " << load + << ", accesses: " << accesses.size() << '\n'); + for (auto *acc : accesses) { + LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n'); + if (mlir::isa(acc)) { + if (useMap.count(acc)) { + 
mlir::emitError( + load.getLoc(), + "The parallel semantics of multiple array_merge_stores per " + "array_load are not supported."); + continue; + } + LLVM_DEBUG(llvm::dbgs() << "map: adding {" << *acc << "} -> {" + << load << "}\n"); + useMap.insert({acc, &op}); + } + } + } + } +} + +namespace { +class ArrayLoadConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ArrayLoadOp load, + mlir::PatternRewriter &rewriter) const override { + LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n"); + rewriter.replaceOpWithNewOp(load, load.getType()); + return mlir::success(); + } +}; + +class ArrayMergeStoreConversion + : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ArrayMergeStoreOp store, + mlir::PatternRewriter &rewriter) const override { + LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n"); + rewriter.eraseOp(store); + return mlir::success(); + } +}; +} // namespace + +static mlir::Type getEleTy(mlir::Type ty) { + if (auto t = dyn_cast_ptrEleTy(ty)) + ty = t; + if (auto t = ty.dyn_cast()) + ty = t.getEleTy(); + // FIXME: keep ptr/heap/ref information. + return ReferenceType::get(ty); +} + +// Extract extents from the ShapeOp/ShapeShiftOp into the result vector. +static void getExtents(llvm::SmallVectorImpl &result, + mlir::Value shape) { + auto *shapeOp = shape.getDefiningOp(); + if (auto s = mlir::dyn_cast(shapeOp)) { + auto e = s.getExtents(); + result.insert(result.end(), e.begin(), e.end()); + return; + } + if (auto s = mlir::dyn_cast(shapeOp)) { + auto e = s.getExtents(); + result.insert(result.end(), e.begin(), e.end()); + return; + } + llvm::report_fatal_error("not a fir.shape/fir.shape_shift op"); +} + +// Place the extents of the array loaded by an ArrayLoadOp into the result +// vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. 
If +// the ArrayLoadOp is loading a fir.box, code will be generated to read the +// extents from the fir.box, and the returned ShapeOp is built with the read +// extents. +// Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp +// argument of the ArrayLoadOp that is returned. +static mlir::Value +getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter, + fir::ArrayLoadOp loadOp, + llvm::SmallVectorImpl &result) { + assert(result.empty()); + if (auto boxTy = loadOp.memref().getType().dyn_cast()) { + auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy) + .cast() + .getDimension(); + auto idxTy = rewriter.getIndexType(); + for (decltype(rank) dim = 0; dim < rank; ++dim) { + auto dimVal = rewriter.create(loc, dim); + auto dimInfo = rewriter.create(loc, idxTy, idxTy, idxTy, + loadOp.memref(), dimVal); + result.emplace_back(dimInfo.getResult(1)); + } + auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank); + return rewriter.create(loc, shapeType, result); + } + getExtents(result, loadOp.shape()); + return loadOp.shape(); +} + +static mlir::Type toRefType(mlir::Type ty) { + if (fir::isa_ref_type(ty)) + return ty; + return fir::ReferenceType::get(ty); +} + +static mlir::Value +genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy, + mlir::Type resTy, mlir::Value alloc, mlir::Value shape, + mlir::Value slice, mlir::ValueRange indices, + mlir::ValueRange typeparams, bool skipOrig = false) { + llvm::SmallVector originated; + if (skipOrig) + originated.assign(indices.begin(), indices.end()); + else + originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(), + shape, indices); + auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType()); + assert(seqTy && seqTy.isa()); + const auto dimension = seqTy.cast().getDimension(); + mlir::Value result = rewriter.create( + loc, eleTy, alloc, shape, slice, + llvm::ArrayRef{originated}.take_front(dimension), + typeparams); + if (dimension < 
originated.size()) + result = rewriter.create( + loc, resTy, result, + llvm::ArrayRef{originated}.drop_front(dimension)); + return result; +} + +namespace { +/// Conversion of fir.array_update and fir.array_modify Ops. +/// If there is a conflict for the update, then we need to perform a +/// copy-in/copy-out to preserve the original values of the array. If there is +/// no conflict, then it is safe to eschew making any copies. +template +class ArrayUpdateConversionBase : public mlir::OpRewritePattern { +public: + explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx, + const ArrayCopyAnalysis &a, + const OperationUseMapT &m) + : mlir::OpRewritePattern{ctx}, analysis{a}, useMap{m} {} + + static llvm::SmallVector recoverTypeParams(mlir::Value val) { + auto *op = val.getDefiningOp(); + if (!fir::hasDynamicSize(fir::dyn_cast_ptrEleTy(val.getType()))) + return {}; + if (auto co = mlir::dyn_cast(op)) + return recoverTypeParams(co.value()); + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + if (auto ao = mlir::dyn_cast(op)) + return {ao.typeparams().begin(), ao.typeparams().end()}; + llvm::report_fatal_error("unexpected buffer"); + } + + static mlir::Value recoverCharLen(mlir::Value val) { + auto params = recoverTypeParams(val); + return params.empty() ? 
mlir::Value{} : params[0]; + } + + void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, + mlir::Value dst, mlir::Value src, mlir::Value shapeOp, + mlir::Type arrTy) const { + auto insPt = rewriter.saveInsertionPoint(); + llvm::SmallVector indices; + llvm::SmallVector extents; + getExtents(extents, shapeOp); + // Build loop nest from column to row. + for (auto sh : llvm::reverse(extents)) { + auto idxTy = rewriter.getIndexType(); + auto ubi = rewriter.create(loc, idxTy, sh); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto ub = rewriter.create(loc, idxTy, ubi, one); + auto loop = rewriter.create(loc, zero, ub, one); + rewriter.setInsertionPointToStart(loop.getBody()); + indices.push_back(loop.getInductionVar()); + } + // Reverse the indices so they are in column-major order. + std::reverse(indices.begin(), indices.end()); + auto ty = getEleTy(arrTy); + auto fromAddr = rewriter.create( + loc, ty, src, shapeOp, mlir::Value{}, + fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp, + indices), + mlir::ValueRange{}); + auto load = rewriter.create(loc, fromAddr); + auto toAddr = rewriter.create( + loc, ty, dst, shapeOp, mlir::Value{}, + fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, + indices), + mlir::ValueRange{}); + rewriter.create(loc, load, toAddr); + rewriter.restoreInsertionPoint(insPt); + } + + /// Copy the RHS element into the LHS and insert copy-in/copy-out between a + /// temp and the LHS if the analysis found potential overlaps between the RHS + /// and LHS arrays. The element copy generator must be provided through \p + /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp. + /// Returns the address of the LHS element inside the loop and the LHS + /// ArrayLoad result. 
+ std::pair + materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter, + ArrayOp update, + const std::function &assignElement, + mlir::Type lhsEltRefType) const { + auto *op = update.getOperation(); + auto *loadOp = useMap.lookup(op); + auto load = mlir::cast(loadOp); + LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n"); + if (analysis.hasPotentialConflict(loadOp)) { + // If there is a conflict between the arrays, then we copy the lhs array + // to a temporary, update the temporary, and copy the temporary back to + // the lhs array. This yields Fortran's copy-in copy-out array semantics. + LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n"); + rewriter.setInsertionPoint(loadOp); + // Copy in. + llvm::SmallVector extents; + auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents); + auto allocmem = rewriter.create( + loc, dyn_cast_ptrOrBoxEleTy(load.memref().getType()), + load.typeparams(), extents); + genArrayCopy(load.getLoc(), rewriter, allocmem, load.memref(), shapeOp, + load.getType()); + rewriter.setInsertionPoint(op); + auto coor = genCoorOp( + rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem, + shapeOp, load.slice(), update.indices(), load.typeparams(), + update->hasAttr(fir::factory::attrFortranArrayOffsets())); + assignElement(coor); + auto *storeOp = useMap.lookup(loadOp); + auto store = mlir::cast(storeOp); + rewriter.setInsertionPoint(storeOp); + // Copy out. + genArrayCopy(store.getLoc(), rewriter, store.memref(), allocmem, shapeOp, + load.getType()); + rewriter.create(loc, allocmem); + return {coor, load.getResult()}; + } + // Otherwise, when there is no conflict (a possible loop-carried + // dependence), the lhs array can be updated in place. 
+ LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n"); + rewriter.setInsertionPoint(op); + auto coorTy = getEleTy(load.getType()); + auto coor = genCoorOp( + rewriter, loc, coorTy, lhsEltRefType, load.memref(), load.shape(), + load.slice(), update.indices(), load.typeparams(), + update->hasAttr(fir::factory::attrFortranArrayOffsets())); + assignElement(coor); + return {coor, load.getResult()}; + } + +private: + const ArrayCopyAnalysis &analysis; + const OperationUseMapT &useMap; +}; + +class ArrayUpdateConversion : public ArrayUpdateConversionBase { +public: + explicit ArrayUpdateConversion(mlir::MLIRContext *ctx, + const ArrayCopyAnalysis &a, + const OperationUseMapT &m) + : ArrayUpdateConversionBase{ctx, a, m} {} + + mlir::LogicalResult + matchAndRewrite(ArrayUpdateOp update, + mlir::PatternRewriter &rewriter) const override { + auto loc = update.getLoc(); + auto assignElement = [&](mlir::Value coor) { + mlir::Value input = update.merge(); + if (mlir::Type inEleTy = fir::dyn_cast_ptrEleTy(input.getType())) { + if (auto inChrTy = inEleTy.dyn_cast()) { + assert(fir::unwrapSequenceType(update.getType()) + .isa()); + fir::factory::genCharacterCopy(input, recoverCharLen(input), coor, + recoverCharLen(coor), rewriter, loc); + } else if (inEleTy.isa()) { + fir::FirOpBuilder builder( + rewriter, + fir::getKindMapping(update->getParentOfType())); + if (!update.typeparams().empty()) { + auto boxTy = fir::BoxType::get(inEleTy); + mlir::Value emptyShape, emptySlice; + auto lhs = rewriter.create( + loc, boxTy, coor, emptyShape, emptySlice, update.typeparams()); + auto rhs = rewriter.create( + loc, boxTy, input, emptyShape, emptySlice, update.typeparams()); + fir::factory::genRecordAssignment(builder, loc, fir::BoxValue(lhs), + fir::BoxValue(rhs)); + } else { + fir::factory::genRecordAssignment(builder, loc, coor, input); + } + } else { + llvm::report_fatal_error("not a legal reference type"); + } + } else { + rewriter.create(loc, input, coor); + } + }; + auto 
lhsEltRefType = toRefType(update.merge().getType()); + auto [_, lhsLoadResult] = materializeAssignment( + loc, rewriter, update, assignElement, lhsEltRefType); + update.replaceAllUsesWith(lhsLoadResult); + rewriter.replaceOp(update, lhsLoadResult); + return mlir::success(); + } +}; + +class ArrayModifyConversion : public ArrayUpdateConversionBase { +public: + explicit ArrayModifyConversion(mlir::MLIRContext *ctx, + const ArrayCopyAnalysis &a, + const OperationUseMapT &m) + : ArrayUpdateConversionBase{ctx, a, m} {} + + mlir::LogicalResult + matchAndRewrite(ArrayModifyOp modify, + mlir::PatternRewriter &rewriter) const override { + auto loc = modify.getLoc(); + auto assignElement = [](mlir::Value) { + // Assignment already materialized by lowering using lhs element address. + }; + auto lhsEltRefType = modify.getResult(0).getType(); + auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( + loc, rewriter, modify, assignElement, lhsEltRefType); + modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); + rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); + return mlir::success(); + } +}; + +class ArrayFetchConversion : public mlir::OpRewritePattern { +public: + explicit ArrayFetchConversion(mlir::MLIRContext *ctx, + const OperationUseMapT &m) + : OpRewritePattern{ctx}, useMap{m} {} + + mlir::LogicalResult + matchAndRewrite(ArrayFetchOp fetch, + mlir::PatternRewriter &rewriter) const override { + auto *op = fetch.getOperation(); + rewriter.setInsertionPoint(op); + auto load = mlir::cast(useMap.lookup(op)); + auto loc = fetch.getLoc(); + auto coor = + genCoorOp(rewriter, loc, getEleTy(load.getType()), + toRefType(fetch.getType()), load.memref(), load.shape(), + load.slice(), fetch.indices(), load.typeparams(), + fetch->hasAttr(fir::factory::attrFortranArrayOffsets())); + if (fir::isa_ref_type(fetch.getType())) + rewriter.replaceOp(fetch, coor); + else + rewriter.replaceOpWithNewOp(fetch, coor); + return mlir::success(); + } + 
+private: + const OperationUseMapT &useMap; +}; +} // namespace + +namespace { +class ArrayValueCopyConverter + : public ArrayValueCopyBase { +public: + void runOnFunction() override { + auto func = getFunction(); + LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '" + << func.getName() << "'\n"); + auto *context = &getContext(); + + // Perform the conflict analysis. + auto &analysis = getAnalysis(); + const auto &useMap = analysis.getUseMap(); + + mlir::OwningRewritePatternList patterns1(context); + patterns1.insert(context, useMap); + patterns1.insert(context, analysis, useMap); + patterns1.insert(context, analysis, useMap); + mlir::ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalOp(); + // Rewrite the array fetch and array update ops. + if (mlir::failed( + mlir::applyPartialConversion(func, target, std::move(patterns1)))) { + mlir::emitError(mlir::UnknownLoc::get(context), + "failure in array-value-copy pass, phase 1"); + signalPassFailure(); + } + + mlir::OwningRewritePatternList patterns2(context); + patterns2.insert(context); + patterns2.insert(context); + target.addIllegalOp(); + if (mlir::failed( + mlir::applyPartialConversion(func, target, std::move(patterns2)))) { + mlir::emitError(mlir::UnknownLoc::get(context), + "failure in array-value-copy pass, phase 2"); + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr fir::createArrayValueCopyPass() { + return std::make_unique(); +} diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -3,17 +3,20 @@ AffinePromotion.cpp AffineDemotion.cpp CharacterConversion.cpp + ArrayValueCopy.cpp Inliner.cpp ExternalNameConversion.cpp RewriteLoop.cpp DEPENDS + FIRBuilder FIRDialect FIRSupport FIROptTransformsPassIncGen RewritePatternsIncGen LINK_LIBS + FIRBuilder FIRDialect MLIRAffineToStandard 
MLIRLLVMIR diff --git a/flang/test/Fir/array-modify.fir b/flang/test/Fir/array-modify.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/array-modify.fir @@ -0,0 +1,130 @@ +// Test array-copy-value pass (copy elision) with fir.array_modify +// RUN: fir-opt %s --array-value-copy | FileCheck %s + +// Test user_defined_assignment(arg0(:), arg1(:)) +func @no_overlap(%arg0: !fir.ref>, %arg1: !fir.ref>) { + %c100 = arith.constant 100 : index + %c99 = arith.constant 99 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = fir.alloca f32 + %1 = fir.shape %c100 : (index) -> !fir.shape<1> + %2 = fir.array_load %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.array<100xf32> + %3 = fir.array_load %arg1(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.array<100xf32> + %4 = fir.do_loop %arg2 = %c0 to %c99 step %c1 unordered iter_args(%arg3 = %2) -> (!fir.array<100xf32>) { + %5 = fir.array_fetch %3, %arg2 : (!fir.array<100xf32>, index) -> f32 + %6:2 = fir.array_modify %arg3, %arg2 : (!fir.array<100xf32>, index) -> (!fir.ref, !fir.array<100xf32>) + fir.store %5 to %0 : !fir.ref + fir.call @user_defined_assignment(%6#0, %0) : (!fir.ref, !fir.ref) -> () + fir.result %6#1 : !fir.array<100xf32> + } + fir.array_merge_store %2, %4 to %arg0 : !fir.array<100xf32>, !fir.array<100xf32>, !fir.ref> + return +} +// CHECK-LABEL: func @no_overlap( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref>) { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 100 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 99 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = fir.alloca f32 +// CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_8:.*]] = fir.undefined !fir.array<100xf32> +// CHECK: %[[VAL_9:.*]] = fir.undefined !fir.array<100xf32> +// CHECK: %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] 
step %[[VAL_4]] unordered iter_args(%[[VAL_12:.*]] = %[[VAL_8]]) -> (!fir.array<100xf32>) { +// CHECK: %[[VAL_13:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : index +// CHECK: %[[VAL_15:.*]] = fir.array_coor %[[VAL_1]](%[[VAL_7]]) %[[VAL_14]] : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref +// CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref +// CHECK: %[[VAL_17:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_11]], %[[VAL_17]] : index +// CHECK: %[[VAL_19:.*]] = fir.array_coor %[[VAL_0]](%[[VAL_7]]) %[[VAL_18]] : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_16]] to %[[VAL_6]] : !fir.ref +// CHECK: fir.call @user_defined_assignment(%[[VAL_19]], %[[VAL_6]]) : (!fir.ref, !fir.ref) -> () +// CHECK: fir.result %[[VAL_8]] : !fir.array<100xf32> +// CHECK: } +// CHECK: return +// CHECK: } + + +// Test user_defined_assignment(arg0(:), arg0(100:1:-1)) +func @overlap(%arg0: !fir.ref>) { + %c100 = arith.constant 100 : index + %c99 = arith.constant 99 : index + %c1 = arith.constant 1 : index + %c-1 = arith.constant -1 : index + %c0 = arith.constant 0 : index + %0 = fir.alloca f32 + %1 = fir.shape %c100 : (index) -> !fir.shape<1> + %2 = fir.array_load %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.array<100xf32> + %3 = fir.slice %c100, %c1, %c-1 : (index, index, index) -> !fir.slice<1> + %4 = fir.array_load %arg0(%1) [%3] : (!fir.ref>, !fir.shape<1>, !fir.slice<1>) -> !fir.array<100xf32> + %5 = fir.do_loop %arg1 = %c0 to %c99 step %c1 unordered iter_args(%arg2 = %2) -> (!fir.array<100xf32>) { + %6 = fir.array_fetch %4, %arg1 : (!fir.array<100xf32>, index) -> f32 + %7:2 = fir.array_modify %arg2, %arg1 : (!fir.array<100xf32>, index) -> (!fir.ref, !fir.array<100xf32>) + fir.store %6 to %0 : !fir.ref + fir.call @user_defined_assignment(%7#0, %0) : (!fir.ref, !fir.ref) -> () + fir.result %7#1 : !fir.array<100xf32> + } + fir.array_merge_store %2, %5 to %arg0 : 
!fir.array<100xf32>, !fir.array<100xf32>, !fir.ref> + return +} +// CHECK-LABEL: func @overlap( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 100 : index +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 99 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -1 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = fir.alloca f32 +// CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_8:.*]] = fir.allocmem !fir.array<100xf32>, %[[VAL_1]] +// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (index) -> index +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_9]], %[[VAL_11]] : index +// CHECK: fir.do_loop %[[VAL_13:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_11]] { +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index +// CHECK: %[[VAL_16:.*]] = fir.array_coor %[[VAL_0]](%[[VAL_7]]) %[[VAL_15]] : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref +// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref +// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_13]], %[[VAL_18]] : index +// CHECK: %[[VAL_20:.*]] = fir.array_coor %[[VAL_8]](%[[VAL_7]]) %[[VAL_19]] : (!fir.heap>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_17]] to %[[VAL_20]] : !fir.ref +// CHECK: } +// CHECK: %[[VAL_21:.*]] = fir.undefined !fir.array<100xf32> +// CHECK: %[[VAL_22:.*]] = fir.slice %[[VAL_1]], %[[VAL_3]], %[[VAL_4]] : (index, index, index) -> !fir.slice<1> +// CHECK: %[[VAL_23:.*]] = fir.undefined !fir.array<100xf32> +// CHECK: %[[VAL_24:.*]] = fir.do_loop %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_3]] unordered iter_args(%[[VAL_26:.*]] = %[[VAL_21]]) -> (!fir.array<100xf32>) { +// 
CHECK: %[[VAL_27:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_27]] : index +// CHECK: %[[VAL_29:.*]] = fir.array_coor %[[VAL_0]](%[[VAL_7]]) {{\[}}%[[VAL_22]]] %[[VAL_28]] : (!fir.ref>, !fir.shape<1>, !fir.slice<1>, index) -> !fir.ref +// CHECK: %[[VAL_30:.*]] = fir.load %[[VAL_29]] : !fir.ref +// CHECK: %[[VAL_31:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_25]], %[[VAL_31]] : index +// CHECK: %[[VAL_33:.*]] = fir.array_coor %[[VAL_8]](%[[VAL_7]]) %[[VAL_32]] : (!fir.heap>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_30]] to %[[VAL_6]] : !fir.ref +// CHECK: fir.call @user_defined_assignment(%[[VAL_33]], %[[VAL_6]]) : (!fir.ref, !fir.ref) -> () +// CHECK: fir.result %[[VAL_21]] : !fir.array<100xf32> +// CHECK: } +// CHECK: %[[VAL_34:.*]] = fir.convert %[[VAL_1]] : (index) -> index +// CHECK: %[[VAL_35:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_36:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_37:.*]] = arith.subi %[[VAL_34]], %[[VAL_36]] : index +// CHECK: fir.do_loop %[[VAL_38:.*]] = %[[VAL_35]] to %[[VAL_37]] step %[[VAL_36]] { +// CHECK: %[[VAL_39:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : index +// CHECK: %[[VAL_41:.*]] = fir.array_coor %[[VAL_8]](%[[VAL_7]]) %[[VAL_40]] : (!fir.heap>, !fir.shape<1>, index) -> !fir.ref +// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_41]] : !fir.ref +// CHECK: %[[VAL_43:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_38]], %[[VAL_43]] : index +// CHECK: %[[VAL_45:.*]] = fir.array_coor %[[VAL_0]](%[[VAL_7]]) %[[VAL_44]] : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_42]] to %[[VAL_45]] : !fir.ref +// CHECK: } +// CHECK: fir.freemem %[[VAL_8]] : !fir.heap> +// CHECK: return +// CHECK: } + +func private @user_defined_assignment(!fir.ref, !fir.ref) diff --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp 
b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp --- a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp +++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp @@ -144,8 +144,8 @@ auto loc = builder.getUnknownLoc(); auto realTy = mlir::FloatType::getF64(ctx); auto cst = builder.createRealZeroConstant(loc, realTy); - EXPECT_TRUE(mlir::isa(cst.getDefiningOp())); - auto cstOp = dyn_cast(cst.getDefiningOp()); + EXPECT_TRUE(mlir::isa(cst.getDefiningOp())); + auto cstOp = dyn_cast(cst.getDefiningOp()); EXPECT_EQ(realTy, cstOp.getType()); EXPECT_EQ(0u, cstOp.value().cast().getValue().convertToDouble()); }