diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -224,6 +224,132 @@
   }
 };
 
+// Code shared between insert_value and extract_value Ops.
+struct ValueOpCommon {
+  // Translate the arguments pertaining to any multidimensional array to
+  // row-major order for LLVM-IR: a run of indices into a nest of LLVM array
+  // types is reversed in place, while a struct index selects the member type
+  // used to interpret the indices that follow.
+  static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
+                         mlir::Type ty) {
+    assert(ty && "type is null");
+    const auto end = attrs.size();
+    for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
+      if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
+        const auto dim = getDimension(seq);
+        if (dim > 1) {
+          auto ub = std::min(i + dim, end);
+          std::reverse(attrs.begin() + i, attrs.begin() + ub);
+          i += dim - 1;
+        }
+        ty = getArrayElementType(seq);
+      } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
+        auto members = st.getBody();
+        auto field = attrs[i].cast<mlir::IntegerAttr>().getInt();
+        assert(field >= 0 &&
+               static_cast<std::size_t>(field) < members.size() &&
+               "struct member index out of range");
+        ty = members[field];
+      } else {
+        llvm_unreachable("index into invalid type");
+      }
+    }
+  }
+
+  // Convert a FIR coordinate list to LLVM positions. Integer attributes are
+  // kept as is; a component is given as a (field name, record type) pair
+  // and is replaced by the i32 index of that field in the record.
+  static llvm::SmallVector<mlir::Attribute>
+  collectIndices(mlir::ConversionPatternRewriter &rewriter,
+                 mlir::ArrayAttr arrAttr) {
+    llvm::SmallVector<mlir::Attribute> attrs;
+    for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
+      if (i->isa<mlir::IntegerAttr>()) {
+        attrs.push_back(*i);
+        continue;
+      }
+      auto fieldName = i->cast<mlir::StringAttr>().getValue();
+      // Do not dereference (or step past) the end of the list when the
+      // record type that must follow the field name is missing.
+      if (++i == e)
+        llvm::report_fatal_error(
+            "field name in fir coordinate is not followed by its record type");
+      auto recTy =
+          i->cast<mlir::TypeAttr>().getValue().cast<fir::RecordType>();
+      attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(),
+                                             recTy.getFieldIndex(fieldName)));
+    }
+    return attrs;
+  }
+
+  // Compute the LLVM position attribute addressing the FIR coordinates
+  // `coor` within an aggregate whose converted LLVM type is `aggregateTy`.
+  static mlir::ArrayAttr
+  getLLVMPosition(mlir::ConversionPatternRewriter &rewriter,
+                  mlir::ArrayAttr coor, mlir::Type aggregateTy) {
+    auto attrs = collectIndices(rewriter, coor);
+    toRowMajor(attrs, aggregateTy);
+    return rewriter.getArrayAttr(attrs);
+  }
+
+private:
+  // Number of directly nested LLVM array types, starting with `ty` itself.
+  static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
+    unsigned result = 1;
+    for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
+         eleTy;
+         eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
+      ++result;
+    return result;
+  }
+
+  // Innermost non-array element type of a (possibly nested) LLVM array type.
+  static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
+    auto eleTy = ty.getElementType();
+    while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
+      eleTy = arrTy.getElementType();
+    return eleTy;
+  }
+};
+
+/// Extract a subobject value from an ssa-value of aggregate type
+struct ExtractValueOpConversion
+    : public FIROpAndTypeConversion<fir::ExtractValueOp>,
+      public ValueOpCommon {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  mlir::LogicalResult
+  doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    auto aggregate = adaptor.getOperands()[0];
+    auto position =
+        getLLVMPosition(rewriter, extractVal.coor(), aggregate.getType());
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
+        extractVal, ty, aggregate, position);
+    return success();
+  }
+};
+
+/// InsertValue is the generalized instruction for the composition of new
+/// aggregate type values.
+struct InsertValueOpConversion
+    : public FIROpAndTypeConversion<fir::InsertValueOp>,
+      public ValueOpCommon {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  mlir::LogicalResult
+  doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    auto aggregate = adaptor.getOperands()[0];
+    auto value = adaptor.getOperands()[1];
+    auto position =
+        getLLVMPosition(rewriter, insertVal.coor(), aggregate.getType());
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
+        insertVal, ty, aggregate, value, position);
+    return success();
+  }
+};
+
 /// InsertOnRange inserts a value into a sequence over a range of offsets.
 struct InsertOnRangeOpConversion
     : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
@@ -318,9 +444,11 @@
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::OwningRewritePatternList pattern(context);
-    pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
-                   InsertOnRangeOpConversion, UndefOpConversion,
-                   UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+    pattern.insert<
+        AddrOfOpConversion, ExtractValueOpConversion, HasValueOpConversion,
+        GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
+        UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
+        typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
diff --git a/flang/lib/Optimizer/CodeGen/DescriptorModel.h b/flang/lib/Optimizer/CodeGen/DescriptorModel.h
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/DescriptorModel.h
@@ -0,0 +1,140 @@
+//===-- DescriptorModel.h -- model of descriptors for codegen ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// LLVM IR dialect models of C++ types.
+//
+// This supplies a set of model builders to decompose the C declaration of a
+// descriptor (as encoded in ISO_Fortran_binding.h and elsewhere) and
+// reconstruct that type in the LLVM IR dialect.
+//
+// TODO: It is understood that this is deeply incorrect as far as building a
+// portability layer for cross-compilation as these reflected types are those of
+// the build machine and not necessarily that of either the host or the target.
+// This assumption that build == host == target is actually pervasive across the
+// compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef OPTIMIZER_DESCRIPTOR_MODEL_H
+#define OPTIMIZER_DESCRIPTOR_MODEL_H
+
+#include "flang/ISO_Fortran_binding.h"
+#include "flang/Runtime/descriptor.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <tuple>
+
+namespace fir {
+
+using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
+
+/// Get the LLVM IR dialect model for building a particular C++ type, `T`.
+template <typename T>
+TypeBuilderFunc getModel();
+// Specializations are constexpr (hence implicitly inline) as this is a header.
+template <>
+constexpr TypeBuilderFunc getModel<void *>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8));
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<unsigned>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<int>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(int) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<unsigned long>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<unsigned long long>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<long long>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(long long) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<Fortran::ISO::CFI_rank_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context,
+                                  sizeof(Fortran::ISO::CFI_rank_t) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<Fortran::ISO::CFI_type_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context,
+                                  sizeof(Fortran::ISO::CFI_type_t) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<Fortran::ISO::CFI_index_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context,
+                                  sizeof(Fortran::ISO::CFI_index_t) * 8);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<Fortran::ISO::CFI_dim_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context);
+    return mlir::LLVM::LLVMArrayType::get(indexTy, 3);
+  };
+}
+template <>
+constexpr TypeBuilderFunc
+getModel<Fortran::ISO::cfi_internal::FlexibleArray<Fortran::ISO::CFI_dim_t>>() {
+  return getModel<Fortran::ISO::CFI_dim_t>();
+}
+
+//===----------------------------------------------------------------------===//
+// Descriptor reflection
+//===----------------------------------------------------------------------===//
+
+/// Get the type model of the field number `Field` in an ISO CFI descriptor.
+template <int Field>
+static constexpr TypeBuilderFunc getDescFieldTypeModel() {
+  Fortran::ISO::Fortran_2018::CFI_cdesc_t dummyDesc{};
+  // The structured binding only compiles if CFI_cdesc_t has exactly 8 fields.
+  auto [a, b, c, d, e, f, g, h] = dummyDesc;
+  auto tup = std::tie(a, b, c, d, e, f, g, h);
+  auto field = std::get<Field>(tup);
+  return getModel<decltype(field)>();
+}
+
+/// An extended descriptor is defined by a class in runtime/descriptor.h. The
+/// addendum fields (numbered 8 and 9, after the 8 ISO fields) are hard-coded
+/// here, unlike the reflection used on the ISO parts, which are a POD.
+template <int Field>
+static constexpr TypeBuilderFunc getExtendedDescFieldTypeModel() {
+  if constexpr (Field == 8) {
+    return getModel<void *>();
+  } else if constexpr (Field == 9) {
+    return getModel<Fortran::runtime::typeInfo::TypeParameterValue>();
+  } else {
+    llvm_unreachable("extended ISO descriptor only has 10 fields");
+  }
+}
+
+} // namespace fir
+
+#endif // OPTIMIZER_DESCRIPTOR_MODEL_H
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h
@@ -13,6 +13,9 @@
 #ifndef FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
 #define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
 
+#include "DescriptorModel.h"
+#include "flang/Lower/Todo.h" // remove when TODO's are done
+#include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Debug.h"
 
 namespace fir {
@@ -26,10 +29,125 @@
     LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
 
     // Each conversion should return a value of type mlir::Type.
+    addConversion([&](BoxType box) { return convertBoxType(box); });
+    addConversion(
+        [&](fir::RecordType derived) { return convertRecordType(derived); });
     addConversion(
         [&](fir::ReferenceType ref) { return convertPointerLike(ref); });
     addConversion(
         [&](SequenceType sequence) { return convertSequenceType(sequence); });
+    addConversion([&](mlir::TupleType tuple) -> mlir::Type {
+      LLVM_DEBUG(llvm::dbgs() << "type convert: " << tuple << '\n');
+      llvm::SmallVector<mlir::Type> inMembers;
+      tuple.getFlattenedTypes(inMembers);
+      llvm::SmallVector<mlir::Type> members;
+      for (auto mem : inMembers) {
+        auto memTy = convertMemberType(mem);
+        // Fail the conversion rather than build a struct with a null member.
+        if (!memTy)
+          return {};
+        members.push_back(memTy);
+      }
+      return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
+                                                    /*isPacked=*/false);
+    });
+  }
+
+  // Is an extended descriptor needed given the element type of a fir.box type?
+  // Extended descriptors are required for derived types.
+  bool requiresExtendedDesc(mlir::Type boxElementType) {
+    auto eleTy = fir::unwrapSequenceType(boxElementType);
+    return eleTy.isa<fir::RecordType>();
+  }
+
+  // Magic value to indicate we do not know the rank of an entity, either
+  // because it is assumed rank or because we have not determined it yet.
+  static constexpr int unknownRank() { return -1; }
+
+  // This corresponds to the descriptor as defined in ISO_Fortran_binding.h
+  // and the addendum defined in descriptor.h.
+  mlir::Type convertBoxType(BoxType box, int rank = unknownRank()) {
+    // (buffer*, ele-size, rank, type-descriptor, attribute, [dims])
+    SmallVector<mlir::Type> parts;
+    mlir::Type ele = box.getEleTy();
+    // remove fir.heap/fir.ref/fir.ptr
+    if (auto removeIndirection = fir::dyn_cast_ptrEleTy(ele))
+      ele = removeIndirection;
+    auto eleTy = convertType(ele);
+    // buffer*
+    if (ele.isa<SequenceType>() && eleTy.isa<mlir::LLVM::LLVMPointerType>())
+      parts.push_back(eleTy);
+    else
+      parts.push_back(mlir::LLVM::LLVMPointerType::get(eleTy));
+    parts.push_back(getDescFieldTypeModel<1>()(&getContext()));
+    parts.push_back(getDescFieldTypeModel<2>()(&getContext()));
+    parts.push_back(getDescFieldTypeModel<3>()(&getContext()));
+    parts.push_back(getDescFieldTypeModel<4>()(&getContext()));
+    parts.push_back(getDescFieldTypeModel<5>()(&getContext()));
+    parts.push_back(getDescFieldTypeModel<6>()(&getContext()));
+    if (rank == unknownRank()) {
+      if (auto seqTy = ele.dyn_cast<SequenceType>())
+        rank = seqTy.getDimension();
+      else
+        rank = 0;
+    }
+    if (rank > 0) {
+      auto rowTy = getDescFieldTypeModel<7>()(&getContext());
+      parts.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, rank));
+    }
+    // opt-type-ptr: i8* (see fir.tdesc)
+    if (requiresExtendedDesc(ele)) {
+      parts.push_back(getExtendedDescFieldTypeModel<8>()(&getContext()));
+      auto rowTy = getExtendedDescFieldTypeModel<9>()(&getContext());
+      parts.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, 1));
+      if (auto recTy = fir::unwrapSequenceType(ele).dyn_cast<fir::RecordType>())
+        if (recTy.getNumLenParams() > 0) {
+          // The descriptor design needs to be clarified regarding the number of
+          // length parameters in the addendum. Since it can change for
+          // polymorphic allocatables, it seems all length parameters cannot
+          // always possibly be placed in the addendum.
+          TODO_NOLOC("extended descriptor derived with length parameters");
+          unsigned numLenParams = recTy.getNumLenParams();
+          parts.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, numLenParams));
+        }
+    }
+    return mlir::LLVM::LLVMPointerType::get(
+        mlir::LLVM::LLVMStructType::getLiteral(&getContext(), parts,
+                                               /*isPacked=*/false));
+  }
+
+  /// Convert fir.box type to the corresponding llvm struct type instead of a
+  /// pointer to this struct type.
+  mlir::Type convertBoxTypeAsStruct(BoxType box) {
+    return convertBoxType(box)
+        .cast<mlir::LLVM::LLVMPointerType>()
+        .getElementType();
+  }
+
+  // Convert the type of a tuple or record member. A fir.box member is laid
+  // out as the descriptor struct itself rather than as a pointer to it.
+  // A failed convertType (null type) is passed through to the caller.
+  mlir::Type convertMemberType(mlir::Type mem) {
+    if (auto box = mem.dyn_cast<BoxType>())
+      return convertBoxTypeAsStruct(box);
+    return convertType(mem);
+  }
+
+  // fir.type --> llvm<"%name = { ty... }">
+  mlir::Type convertRecordType(fir::RecordType derived) {
+    auto name = derived.getName();
+    auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
+    llvm::SmallVector<mlir::Type> members;
+    for (auto mem : derived.getTypeList()) {
+      auto memTy = convertMemberType(mem.second);
+      // Fail the conversion rather than give the struct a null member.
+      if (!memTy)
+        return mlir::Type();
+      members.push_back(memTy);
+    }
+    if (mlir::succeeded(st.setBody(members, /*isPacked=*/false)))
+      return st;
+    return mlir::Type();
   }
 
   template <typename A>
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -167,3 +167,67 @@
 func @test_unreachable() {
   fir.unreachable
 }
+
+// -----
+
+// Test fir.extract_value operation conversion with derived type.
+
+func @extract_derived_type() -> f32 {
+  %0 = fir.undefined !fir.type<derived{f:f32}>
+  %1 = fir.extract_value %0, ["f", !fir.type<derived{f:f32}>] : (!fir.type<derived{f:f32}>) -> f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: llvm.func @extract_derived_type
+// CHECK: %[[STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<"derived", (f32)>
+// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0 : i32] : !llvm.struct<"derived", (f32)>
+// CHECK: llvm.return %[[VALUE]] : f32
+
+// -----
+
+// Test fir.extract_value operation conversion with a multi-dimensional array
+// of tuple.
+
+func @extract_array(%a : !fir.array<10x10xtuple<i32, f32>>) -> f32 {
+  %0 = fir.extract_value %a, [5 : index, 4 : index, 1 : index] : (!fir.array<10x10xtuple<i32, f32>>) -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @extract_array(
+// CHECK-SAME: %[[ARR:.*]]: !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[ARR]][4 : index, 5 : index, 1 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: llvm.return %[[VALUE]] : f32
+
+// -----
+
+// Test fir.insert_value operation conversion with a multi-dimensional array
+// of tuple.
+
+func @insert_array(%a : !fir.array<10x10xtuple<i32, f32>>) {
+  %f = arith.constant 2.0 : f32
+  %i = arith.constant 1 : i32
+  %0 = fir.insert_value %a, %i, [5 : index, 4 : index, 0 : index] : (!fir.array<10x10xtuple<i32, f32>>, i32) -> !fir.array<10x10xtuple<i32, f32>>
+  %1 = fir.insert_value %a, %f, [5 : index, 4 : index, 1 : index] : (!fir.array<10x10xtuple<i32, f32>>, f32) -> !fir.array<10x10xtuple<i32, f32>>
+  return
+}
+
+// CHECK-LABEL: llvm.func @insert_array(
+// CHECK-SAME: %[[ARR:.*]]: !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[ARR]][4 : index, 5 : index, 0 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[ARR]][4 : index, 5 : index, 1 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: llvm.return
+
+// -----
+
+// Test fir.insert_value operation conversion with a tuple.
+
+func @insert_tuple(%a : tuple<i32, f32>) {
+  %f = arith.constant 2.0 : f32
+  %1 = fir.insert_value %a, %f, [1 : index] : (tuple<i32, f32>, f32) -> tuple<i32, f32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @insert_tuple(
+// CHECK-SAME: %[[TUPLE:.*]]: !llvm.struct<(i32, f32)>
+// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[TUPLE]][1 : index] : !llvm.struct<(i32, f32)>
+// CHECK: llvm.return