Index: flang/include/flang/Optimizer/Dialect/FIRAttr.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIRAttr.h +++ flang/include/flang/Optimizer/Dialect/FIRAttr.h @@ -5,11 +5,15 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #ifndef OPTIMIZER_DIALECT_FIRATTR_H #define OPTIMIZER_DIALECT_FIRATTR_H -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" namespace mlir { class DialectAsmParser; @@ -21,6 +25,7 @@ class FIROpsDialect; namespace detail { +struct OpaqueAttributeStorage; struct RealAttributeStorage; struct TypeAttributeStorage; } // namespace detail @@ -127,6 +132,21 @@ llvm::APFloat getValue() const; }; +/// An opaque attribute is used to provide dictionary lookups of pointers. The +/// underlying type of the pointee object is left up to the client. Opaque +/// attributes are always constructed as null pointers when parsing. 
+class OpaqueAttr + : public mlir::Attribute::AttrBase { +public: + using Base::Base; + + static constexpr llvm::StringRef getAttrName() { return "opaque"; } + static OpaqueAttr get(mlir::MLIRContext *ctxt, void *pointer); + + void *getPointer() const; +}; + mlir::Attribute parseFirAttribute(FIROpsDialect *dialect, mlir::DialectAsmParser &parser, mlir::Type type); Index: flang/include/flang/Optimizer/Dialect/FIRDialect.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIRDialect.h +++ flang/include/flang/Optimizer/Dialect/FIRDialect.h @@ -5,13 +5,22 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #ifndef OPTIMIZER_DIALECT_FIRDIALECT_H #define OPTIMIZER_DIALECT_FIRDIALECT_H +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/IR/Dialect.h" #include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/Passes.h" namespace fir { @@ -32,43 +41,59 @@ mlir::DialectAsmPrinter &p) const override; }; -/// Register the dialect with the provided registry. -inline void registerFIRDialects(mlir::DialectRegistry ®istry) { +/// The FIR codegen dialect is a dialect containing a small set of transient +/// operations used exclusively during code generation. +class FIRCodeGenDialect final : public mlir::Dialect { +public: + explicit FIRCodeGenDialect(mlir::MLIRContext *ctx); + virtual ~FIRCodeGenDialect(); + + static llvm::StringRef getDialectNamespace() { return "fircg"; } +}; + +/// Register and load all the dialects used by flang. 
+inline void registerAndLoadDialects(mlir::MLIRContext &ctx) { + auto registry = ctx.getDialectRegistry(); // clang-format off registry.insert(); + mlir::vector::VectorDialect>(); // clang-format on + registry.loadAll(&ctx); } /// Register the standard passes we use. This comes from registerAllPasses(), /// but is a smaller set since we aren't using many of the passes found there. inline void registerGeneralPasses() { - mlir::createCanonicalizerPass(); - mlir::createCSEPass(); - mlir::createSuperVectorizePass({}); - mlir::createLoopUnrollPass(); - mlir::createLoopUnrollAndJamPass(); - mlir::createSimplifyAffineStructuresPass(); - mlir::createLoopFusionPass(); - mlir::createLoopInvariantCodeMotionPass(); - mlir::createAffineLoopInvariantCodeMotionPass(); - mlir::createPipelineDataTransferPass(); - mlir::createLowerAffinePass(); - mlir::createLoopTilingPass(0); - mlir::createLoopCoalescingPass(); - mlir::createAffineDataCopyGenerationPass(0, 0); - mlir::createMemRefDataFlowOptPass(); - mlir::createStripDebugInfoPass(); - mlir::createPrintOpStatsPass(); - mlir::createInlinerPass(); - mlir::createSymbolDCEPass(); - mlir::createLocationSnapshotPass({}); + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + mlir::registerAffineLoopFusionPass(); + mlir::registerLoopInvariantCodeMotionPass(); + mlir::registerLoopCoalescingPass(); + mlir::registerStripDebugInfoPass(); + mlir::registerPrintOpStatsPass(); + mlir::registerInlinerPass(); + mlir::registerSCCPPass(); + mlir::registerMemRefDataFlowOptPass(); + mlir::registerSymbolDCEPass(); + mlir::registerLocationSnapshotPass(); + mlir::registerAffinePipelineDataTransferPass(); + + mlir::registerAffineVectorizePass(); + mlir::registerAffineLoopUnrollPass(); + mlir::registerAffineLoopUnrollAndJamPass(); + mlir::registerSimplifyAffineStructuresPass(); + mlir::registerAffineLoopInvariantCodeMotionPass(); + mlir::registerAffineLoopTilingPass(); + mlir::registerAffineDataCopyGenerationPass(); + + 
mlir::registerConvertAffineToStandardPass(); } inline void registerFIRPasses() { registerGeneralPasses(); } Index: flang/include/flang/Optimizer/Dialect/FIROps.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIROps.h +++ flang/include/flang/Optimizer/Dialect/FIROps.h @@ -18,7 +18,7 @@ namespace fir { class FirEndOp; -class LoopOp; +class DoLoopOp; class RealAttr; void buildCmpFOp(mlir::OpBuilder &builder, mlir::OperationState &result, @@ -29,7 +29,7 @@ mlir::Value rhs); unsigned getCaseArgumentOffset(llvm::ArrayRef cases, unsigned dest); -LoopOp getForInductionVarOwner(mlir::Value val); +DoLoopOp getForInductionVarOwner(mlir::Value val); bool isReferenceLike(mlir::Type type); mlir::ParseResult isValidCaseAttr(mlir::Attribute attr); mlir::ParseResult parseCmpfOp(mlir::OpAsmParser &parser, Index: flang/include/flang/Optimizer/Dialect/FIROps.td =================================================================== --- flang/include/flang/Optimizer/Dialect/FIROps.td +++ flang/include/flang/Optimizer/Dialect/FIROps.td @@ -15,9 +15,11 @@ #define FIR_DIALECT_FIR_OPS include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "flang/Optimizer/TypePredicates.td" def fir_Dialect : Dialect { let name = "fir"; @@ -32,23 +34,20 @@ // Fortran intrinsic types def fir_CharacterType : Type()">, "FIR character type">; -def fir_ComplexType : Type()">, +def fir_ComplexType : Type()">, "FIR complex type">; -def fir_IntegerType : Type()">, - "FIR integer type">; def fir_LogicalType : Type()">, "FIR logical type">; def fir_RealType : Type()">, "FIR real type">; +def fir_VectorType : Type()">, + "FIR vector type">; // Generalized FIR and standard dialect types representing intrinsic types -def AnyIntegerLike : TypeConstraint, "any integer">; def 
AnyLogicalLike : TypeConstraint, "any logical">; def AnyRealLike : TypeConstraint, "any real">; -def AnyIntegerType : Type; // Fortran derived (user defined) type def fir_RecordType : Type()">, @@ -61,26 +60,12 @@ // Composable types def AnyCompositeLike : TypeConstraint, "any composite">; + fir_VectorType.predicate, IsTupleTypePred, fir_CharacterType.predicate]>, + "any composite">; -// Reference to an entity type -def fir_ReferenceType : Type()">, - "reference type">; - -// Reference to an ALLOCATABLE attribute type -def fir_HeapType : Type()">, - "allocatable type">; - -// Reference to a POINTER attribute type -def fir_PointerType : Type()">, - "pointer type">; - -// Reference types -def AnyReferenceLike : TypeConstraint, "any reference">; - -// A descriptor tuple (captures a reference to an entity and other information) -def fir_BoxType : Type()">, "box type">; +// The legal types of global symbols +def AnyAddressableLike : TypeConstraint, "any addressable">; // CHARACTER type descriptor. A pair of a data reference and a LEN value. 
def fir_BoxCharType : Type()">, @@ -98,33 +83,23 @@ "any reference or box">; // A vector of Fortran triple notation describing a multidimensional array -def fir_DimsType : Type()">, "dim type">; -def AnyEmboxLike : TypeConstraint, - "any legal embox argument type">; -def AnyEmboxArg : Type; +def fir_ShapeType : Type()">, "shape type">; +def fir_ShapeShiftType : Type()">, + "shape shift type">; +def AnyShapeLike : TypeConstraint, "any legal shape type">; +def AnyShapeType : Type; +def fir_SliceType : Type()">, "slice type">; // A type descriptor's type def fir_TypeDescType : Type()">, "type desc type">; -// A field (in a RecordType) argument's type -def fir_FieldType : Type()">, "field type">; - -// A LEN parameter (in a RecordType) argument's type -def fir_LenType : Type()">, - "LEN parameter type">; - def AnyComponentLike : TypeConstraint, "any coordinate index">; def AnyComponentType : Type; -def AnyCoordinateLike : TypeConstraint, "any coordinate index">; -def AnyCoordinateType : Type; - // Base class for FIR operations. // All operations automatically get a prefix of "fir.". 
class fir_Op traits> @@ -140,10 +115,10 @@ } // Base builder for allocate operations -def fir_AllocateOpBuilder : - OpBuilderDAG<(ins "Type":$inType, CArg<"ValueRange", "{}">:$lenParams, - CArg<"ValueRange", "{}">:$sizes, - CArg<"ArrayRef", "{}">:$attributes), +def fir_AllocateOpBuilder : OpBuilderDAG<(ins "mlir::Type":$inType, + CArg<"mlir::ValueRange", "{}">:$lenParams, + CArg<"mlir::ValueRange", "{}">:$sizes, + CArg<"llvm::ArrayRef", "{}">:$attributes), [{ $_state.addTypes(getRefTy(inType)); $_state.addAttribute("in_type", TypeAttr::get(inType)); @@ -151,21 +126,22 @@ $_state.addAttributes(attributes); }]>; -def fir_NamedAllocateOpBuilder : - OpBuilderDAG<(ins "Type":$inType, "StringRef":$name, - CArg<"ValueRange", "{}">:$lenParams, CArg<"ValueRange", "{}">:$sizes, - CArg<"ArrayRef", "{}">:$attributes), +def fir_NamedAllocateOpBuilder : OpBuilderDAG<(ins "mlir::Type":$inType, + "llvm::StringRef":$name, CArg<"mlir::ValueRange", "{}">:$lenParams, + CArg<"mlir::ValueRange","{}">:$sizes, + CArg<"llvm::ArrayRef", "{}">:$attributes), [{ $_state.addTypes(getRefTy(inType)); $_state.addAttribute("in_type", TypeAttr::get(inType)); - $_state.addAttribute("name", $_builder.getStringAttr(name)); + if (!name.empty()) + $_state.addAttribute("name", $_builder.getStringAttr(name)); $_state.addOperands(sizes); $_state.addAttributes(attributes); }]>; -def fir_OneResultOpBuilder : - OpBuilderDAG<(ins "Type":$resultType, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), +def fir_OneResultOpBuilder : OpBuilderDAG<(ins "mlir::Type":$resultType, + "mlir::ValueRange":$operands, + CArg<"llvm::ArrayRef", "{}">:$attributes), [{ if (resultType) $_state.addTypes(resultType); @@ -197,9 +173,10 @@ ); } -class fir_AllocatableOp traits = []> : +class fir_AllocatableOp traits = []> : fir_AllocatableBaseOp])>, + !listconcat(traits, [MemoryEffects<[MemAlloc]>])>, fir_TwoBuilders, Arguments<(ins TypeAttr:$in_type, Variadic:$args)> { @@ -274,19 +251,19 @@ return val.getInt(); return 0; } 
- + operand_range getLenParams() { return {operand_begin(), operand_begin() + numLenParams()}; } - + unsigned numShapeOperands() { return operand_end() - operand_begin() + numLenParams(); } - + operand_range getShapeOperands() { return {operand_begin() + numLenParams(), operand_end()}; } - + static mlir::Type getRefTy(mlir::Type ty); /// Get the input type of the allocation @@ -309,7 +286,8 @@ // Memory SSA operations //===----------------------------------------------------------------------===// -def fir_AllocaOp : fir_AllocatableOp<"alloca"> { +def fir_AllocaOp : + fir_AllocatableOp<"alloca", AutomaticAllocationScopeResource> { let summary = "allocate storage for a temporary on the stack given a type"; let description = [{ This primitive operation is used to allocate an object on the stack. A @@ -336,6 +314,39 @@ (`len1`, `len2`) to the type `PT`. Finally, the operation is undefined if the ssa-value `%c` is negative. + + Fortran Semantics: + There is no language mechanism in Fortran to allocate space on the stack + like C's `alloca()` function. Therefore fir.alloca is not control-flow + dependent. However, the lifetime of a stack allocation is often limited to + a small region and a legal implementation may reuse stack storage in other + regions when there is no conflict. For example, take the following code + fragment. + + ```fortran + CALL foo(1) + CALL foo(2) + CALL foo(3) + ``` + + A legal implementation can allocate a stack slot and initialize it with the + constant `1`, then pass that by reference to foo. Likewise for the second + and third calls to foo, each stack slot being initialized accordingly. It is + also a conforming implementation to reuse the same stack slot for all three + calls, just initializing each in turn. This is possible as the lifetime of + the copy of each constant need not exceed that of the CALL statement. + Indeed, a user would likely expect a good Fortran compiler to perform such + an optimization. 
+ + Until Fortran 2018, procedures defaulted to non-recursive. A legal + implementation could therefore convert stack allocations to global + allocations. Such a conversion effectively adds the SAVE attribute to all + variables. + + Some temporary entities (large arrays) probably should not be stack + allocated as stack space can often be limited. A legal implementation can + convert these large stack allocations to heap allocations regardless of + whether the procedure is recursive or not. }]; let results = (outs fir_ReferenceType); @@ -344,6 +355,8 @@ mlir::Type outType = getType(); if (!outType.isa()) return emitOpError("must be a !fir.ref type"); + if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) + return emitOpError("cannot allocate !fir.box of unknown rank or type"); return mlir::success(); }]; @@ -352,7 +365,7 @@ }]; } -def fir_LoadOp : fir_OneResultOp<"load", [MemoryEffects<[MemRead]>]> { +def fir_LoadOp : fir_OneResultOp<"load"> { let summary = "load a value from a memory reference"; let description = [{ Load a value from a memory reference into an ssa-value (virtual register). @@ -368,19 +381,23 @@ or null. 
}]; - let arguments = (ins AnyReferenceLike:$memref); + let arguments = (ins Arg:$memref); - let builders = [ - OpBuilderDAG<(ins "Value":$refVal), + let builders = [OpBuilderDAG<(ins "mlir::Value":$refVal), [{ if (!refVal) { mlir::emitError($_state.location, "LoadOp has null argument"); return; } - auto refTy = refVal.getType().cast(); + auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); + if (!eleTy) { + mlir::emitError($_state.location, "not a memory reference type"); + return; + } $_state.addOperands(refVal); - $_state.addTypes(refTy.getEleTy()); - }]>]; + $_state.addTypes(eleTy); + }] + >]; let parser = [{ mlir::Type type; @@ -409,7 +426,7 @@ }]; } -def fir_StoreOp : fir_Op<"store", [MemoryEffects<[MemWrite]>]> { +def fir_StoreOp : fir_Op<"store", []> { let summary = "store an SSA-value to a memory location"; let description = [{ @@ -428,7 +445,8 @@ `%p`, is undefined or null. }]; - let arguments = (ins AnyType:$value, AnyReferenceLike:$memref); + let arguments = (ins AnyType:$value, + Arg:$memref); let parser = [{ mlir::Type type; @@ -458,6 +476,8 @@ let verifier = [{ if (value().getType() != fir::dyn_cast_ptrEleTy(memref().getType())) return emitOpError("store value type must match memory reference type"); + if (fir::isa_unknown_size_box(value().getType())) + return emitOpError("cannot store !fir.box of unknown rank or type"); return mlir::success(); }]; @@ -491,7 +511,25 @@ }]; } -def fir_AllocMemOp : fir_AllocatableOp<"allocmem"> { +def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoSideEffect]> { + let summary = "explicit polymorphic zero value of some type"; + let description = [{ + Constructs an ssa-value of the specified type with a value of zero for all + bits. + + ```mlir + %a = fir.zero_bits !fir.box>> + ``` + + The example creates a value of type box where all bits are zero. 
+ }]; + + let results = (outs AnyType:$intype); + + let assemblyFormat = "type($intype) attr-dict"; +} + +def fir_AllocMemOp : fir_AllocatableOp<"allocmem", DefaultResource> { let summary = "allocate storage on the heap for an object of a given type"; let description = [{ @@ -512,6 +550,8 @@ mlir::Type outType = getType(); if (!outType.dyn_cast()) return emitOpError("must be a !fir.heap type"); + if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) + return emitOpError("cannot allocate !fir.box of unknown rank or type"); return mlir::success(); }]; @@ -537,7 +577,7 @@ ``` }]; - let arguments = (ins fir_HeapType:$heapref); + let arguments = (ins Arg:$heapref); let assemblyFormat = "$heapref attr-dict `:` type($heapref)"; } @@ -605,11 +645,11 @@ list traits = []> : fir_SwitchTerminatorOp { let skipDefaultBuilders = 1; - let builders = [ - OpBuilderDAG<(ins "Value":$selector, "ArrayRef":$compareOperands, - "ArrayRef":$destinations, - CArg<"ArrayRef", "{}">:$destOperands, - CArg<"ArrayRef", "{}">:$attributes), + let builders = [OpBuilderDAG<(ins "mlir::Value":$selector, + "llvm::ArrayRef":$compareOperands, + "llvm::ArrayRef":$destinations, + CArg<"llvm::ArrayRef", "{}">:$destOperands, + CArg<"llvm::ArrayRef", "{}">:$attributes), [{ $_state.addOperands(selector); llvm::SmallVector ivalues; @@ -634,11 +674,12 @@ } } $_state.addAttribute(getOperandSegmentSizeAttr(), - $_builder.getI32VectorAttr({1, 0, sumArgs})); + $_builder.getI32VectorAttr({1, 0, sumArgs})); $_state.addAttribute(getTargetOffsetAttr(), - $_builder.getI32VectorAttr(argOffs)); + $_builder.getI32VectorAttr(argOffs)); $_state.addAttributes(attributes); - }]>]; + }] + >]; let parser = [{ mlir::OpAsmParser::OperandType selector; @@ -709,7 +750,7 @@ let verifier = [{ if (!(getSelector().getType().isa() || getSelector().getType().isa() || - getSelector().getType().isa())) + getSelector().getType().isa())) return emitOpError("must be an integer"); auto cases = 
(*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); @@ -793,16 +834,18 @@ let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "Value":$selector, - "ArrayRef":$compareAttrs, - "ArrayRef":$cmpOperands, "ArrayRef":$destinations, - CArg<"ArrayRef", "{}">:$destOperands, - CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilderDAG<(ins "Value":$selector, - "ArrayRef":$compareAttrs, "ArrayRef":$cmpOpList, - "ArrayRef":$destinations, - CArg<"ArrayRef", "{}">:$destOperands, - CArg<"ArrayRef", "{}">:$attributes)>]; + OpBuilderDAG<(ins "mlir::Value":$selector, + "llvm::ArrayRef":$compareAttrs, + "llvm::ArrayRef":$cmpOperands, + "llvm::ArrayRef":$destinations, + CArg<"llvm::ArrayRef", "{}">:$destOperands, + CArg<"llvm::ArrayRef", "{}">:$attributes)>, + OpBuilderDAG<(ins "mlir::Value":$selector, + "llvm::ArrayRef":$compareAttrs, + "llvm::ArrayRef":$cmpOpList, + "llvm::ArrayRef":$destinations, + CArg<"llvm::ArrayRef", "{}">:$destOperands, + CArg<"llvm::ArrayRef", "{}">:$attributes)>]; let parser = "return parseSelectCase(parser, result);"; @@ -835,7 +878,7 @@ let verifier = [{ if (!(getSelector().getType().isa() || getSelector().getType().isa() || - getSelector().getType().isa() || + getSelector().getType().isa() || getSelector().getType().isa() || getSelector().getType().isa())) return emitOpError("must be an integer, character, or logical"); @@ -887,15 +930,15 @@ }]; let skipDefaultBuilders = 1; - let builders = [ - OpBuilderDAG<(ins "Value":$selector, - "ArrayRef":$typeOperands, - "ArrayRef":$destinations, - CArg<"ArrayRef", "{}">:$destOperands, - CArg<"ArrayRef", "{}">:$attributes), + let builders = [OpBuilderDAG<(ins "mlir::Value":$selector, + "llvm::ArrayRef":$typeOperands, + "llvm::ArrayRef":$destinations, + CArg<"llvm::ArrayRef", "{}">:$destOperands, + CArg<"llvm::ArrayRef", "{}">:$attributes), [{ $_state.addOperands(selector); - $_state.addAttribute(getCasesAttr(), $_builder.getArrayAttr(typeOperands)); + 
$_state.addAttribute(getCasesAttr(), + $_builder.getArrayAttr(typeOperands)); const auto count = destinations.size(); for (auto d : destinations) $_state.addSuccessors(d); @@ -913,11 +956,12 @@ } } $_state.addAttribute(getOperandSegmentSizeAttr(), - $_builder.getI32VectorAttr({1, 0, sumArgs})); + $_builder.getI32VectorAttr({1, 0, sumArgs})); $_state.addAttribute(getTargetOffsetAttr(), - $_builder.getI32VectorAttr(argOffs)); + $_builder.getI32VectorAttr(argOffs)); $_state.addAttributes(attributes); - }]>]; + }] + >]; let parser = "return parseSelectType(parser, result);"; @@ -1016,9 +1060,11 @@ let assemblyFormat = "$resval attr-dict `:` type($resval)"; } +//===------------------------------------------------------------------------===// // Operations on !fir.box type objects +//===------------------------------------------------------------------------===// -def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> { +def fir_EmboxOp : fir_Op<"embox", [NoSideEffect, AttrSizedOperandSegments]> { let summary = "boxes a given reference and (optional) dimension information"; let description = [{ @@ -1030,85 +1076,84 @@ ```mlir %c1 = constant 1 : index %c10 = constant 10 : index - %4 = fir.dims(%c1, %c10, %c1) : (index, index, index) -> !fir.dims<1> %5 = ... : !fir.ref> - %6 = fir.embox %5, %4 : (!fir.ref>, !fir.dims<1>) -> !fir.box> + %6 = fir.embox %5 : (!fir.ref>) -> !fir.box> ``` The descriptor tuple may contain additional implementation-specific information through the use of additional attributes. 
}]; - let arguments = (ins AnyReferenceLike:$memref, Variadic:$args); + let arguments = (ins + AnyReferenceLike:$memref, + Optional:$shape, + Optional:$slice, + Variadic:$lenParams, + OptionalAttr:$accessMap + ); let results = (outs fir_BoxType); - let parser = "return parseEmboxOp(parser, result);"; + let builders = [ + OpBuilderDAG<(ins "llvm::ArrayRef":$resultTypes, + "mlir::Value":$memref, CArg<"mlir::Value", "{}">:$shape, + CArg<"mlir::Value", "{}">:$slice, + CArg<"mlir::ValueRange", "{}">:$lenParams), + [{ return build($_builder, $_state, resultTypes, memref, shape, slice, + lenParams, mlir::AffineMapAttr{}); }]> + ]; - let printer = [{ - p << getOperationName() << ' '; - p.printOperand(memref()); - if (hasLenParams()) { - p << '('; - p.printOperands(getLenParams()); - p << ')'; - } - if (getNumOperands() == 2) { - p << ", "; - p.printOperands(dims()); - } else if (auto map = (*this)->getAttr(layoutName())) { - p << " [" << map << ']'; - } - p.printOptionalAttrDict(getAttrs(), {layoutName(), lenpName()}); - p << " : "; - p.printFunctionalType(getOperation()); + let assemblyFormat = [{ + $memref (`(` $shape^ `)`)? (`[` $slice^ `]`)? (`typeparams` $lenParams^)? (`map` $accessMap^)? 
attr-dict `:` functional-type(operands, results) }]; let verifier = [{ + auto eleTy = fir::dyn_cast_ptrEleTy(memref().getType()); + if (!eleTy) + return emitOpError("must embox a memory reference type"); + bool isArray = false; + if (auto seqTy = eleTy.dyn_cast()) { + eleTy = seqTy.getEleTy(); + isArray = true; + } if (hasLenParams()) { - auto lenParams = numLenParams(); - auto eleTy = fir::dyn_cast_ptrEleTy(memref().getType()); - if (!eleTy) - return emitOpError("must embox a memory reference type"); + auto lenPs = numLenParams(); if (auto rt = eleTy.dyn_cast()) { - if (lenParams != rt.getNumLenParams()) + if (lenPs != rt.getNumLenParams()) return emitOpError("number of LEN params does not correspond" " to the !fir.type type"); + } else if (auto strTy = eleTy.dyn_cast()) { + if (strTy.getLen() != fir::CharacterType::unknownLen()) + return emitOpError("CHARACTER already has static LEN"); } else { - return emitOpError("LEN parameters require !fir.type type"); + return emitOpError("LEN parameters require CHARACTER or derived type"); } - for (auto lp : getLenParams()) - if (lp.getType().isa()) + for (auto lp : lenParams()) + if (!fir::isa_integer(lp.getType())) return emitOpError("LEN parameters must be integral type"); } - if (dims().size() == 0) { - // Ok. If there is no dims and no layout map, then emboxing a scalar. - // TODO: Should the type be enforced? It already must agree. 
- } else if (dims().size() == 1) { - auto d = *dims().begin(); - if (!d.getType().isa()) - return emitOpError("dimension argument must have !fir.dims type"); - } else { - return emitOpError("embox can only have one !fir.dim argument"); + if (getShape()) { + auto shapeTy = getShape().getType(); + if (!(shapeTy.isa() || shapeTy.isa())) + return emitOpError("must be shape or shapeshift type"); + if (!isArray) + return emitOpError("shape must not be provided for a scalar"); + } + if (getSlice()) { + auto sliceTy = getSlice().getType(); + if (!sliceTy.isa()) + return emitOpError("must be a slice type"); + if (!isArray) + return emitOpError("slice must not be provided for a scalar"); } return mlir::success(); }]; let extraClassDeclaration = [{ - static constexpr llvm::StringRef layoutName() { return "layout_map"; } - static constexpr llvm::StringRef lenpName() { return "len_param_count"; } - bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } - unsigned numLenParams() { - if (auto x = (*this)->getAttrOfType(lenpName())) - return x.getInt(); - return 0; - } - operand_range getLenParams() { - return {operand_begin(), operand_begin() + numLenParams()}; - } - operand_range dims() { - return {operand_begin() + numLenParams() + 1, operand_end()}; - } + mlir::Value getShape() { return shape(); } + mlir::Value getSlice() { return slice(); } + bool hasLenParams() { return !lenParams().empty(); } + unsigned numLenParams() { return lenParams().size(); } }]; } @@ -1135,7 +1180,7 @@ }]; let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len); - + let results = (outs fir_BoxCharType); let assemblyFormat = [{ @@ -1251,7 +1296,7 @@ ```mlir %40 = ... 
: !fir.box> - %41:6 = fir.unbox %40 : (!fir.box>) -> (!fir.ref>, i32, i32, !fir.tdesc>, i32, !fir.dims<4>) + %41:6 = fir.unbox %40 : (!fir.box>) -> (!fir.ref>, i32, i32, !fir.tdesc>, i32, !fir.array) ``` }]; @@ -1263,7 +1308,7 @@ AnyIntegerLike, // rank of data fir_TypeDescType, // abstract type descriptor AnyIntegerLike, // attribute flags (bitfields) - fir_DimsType // dimension information (if any) + fir_SequenceType // dimension information (if any) ); } @@ -1363,16 +1408,17 @@ ```mlir %c1 = constant 0 : i32 - %52:3 = fir.box_dims %40, %c1 : (!fir.box>, i32) -> (i32, i32, i32) + %52:3 = fir.box_dims %40, %c1 : (!fir.box>, i32) -> (index, index, index) ``` The above is a request to return the left most row (at index 0) triple from - the box. The triple will be the lower bound, upper bound, and stride. + the box. The triple will be the lower bound, extent, and byte-stride, which + are the values encoded in a standard descriptor. }]; let arguments = (ins fir_BoxType:$val, AnyIntegerLike:$dim); - let results = (outs AnyIntegerLike, AnyIntegerLike, AnyIntegerLike); + let results = (outs Index, Index, Index); let assemblyFormat = [{ $val `,` $dim attr-dict `:` functional-type(operands, results) @@ -1433,12 +1479,13 @@ let description = [{ Determine if the boxed value has a positive (> 0) rank. This will return true if the originating box value was from a fir.embox with a memory - reference value that had the type !fir.array and/or a dims argument. + reference value that had the type !fir.array and/or a shape argument. ```mlir %r = ... 
: !fir.ref - %d = fir.gendims(1, 100, 1) : (i32, i32, i32) -> !fir.dims<1> - %b = fir.embox %r, %d : (!fir.ref, !fir.dims<1>) -> !fir.box + %c_100 = constant 100 : index + %d = fir.shape %c_100 : (index) -> !fir.shape<1> + %b = fir.embox %r(%d) : (!fir.ref, !fir.shape<1>) -> !fir.box %a = fir.box_isarray %b : (!fir.box) -> i1 // true ``` }]; @@ -1528,7 +1575,322 @@ let results = (outs fir_TypeDescType); } +//===----------------------------------------------------------------------===// +// Array value operations +//===----------------------------------------------------------------------===// + +def fir_ArrayLoadOp : fir_Op<"array_load", [AttrSizedOperandSegments]> { + + let summary = "Load an array as a value."; + + let description = [{ + Load an entire array as a single SSA value. + + ```fortran + real :: a(o:n,p:m) + ... + ... = ... a ... + ``` + + One can use `fir.array_load` to produce an ssa-value that captures an + immutable value of the entire array `a`, as in the Fortran array expression + shown above. Subsequent changes to the memory containing the array do not + alter its composite value. This operation lets one load an array as a + value while applying a runtime shape, shift, or slice to the memory + reference, and its semantics guarantee immutability. + + ```mlir + %s = fir.shape_shift %o, %n, %p, %m : (index, index, index, index) -> !fir.shapeshift<2> + // load the entire array 'a' + %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shapeshift<2>) -> !fir.array + // a fir.store here into array %a does not change %v + ``` + }]; + + let arguments = (ins + Arg:$memref, + Optional:$shape, + Optional:$slice, + Variadic:$lenParams + ); + + let results = (outs fir_SequenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`[`$slice^`]`)? (`typeparams` $lenParams^)?
attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto eleTy = fir::dyn_cast_ptrEleTy(memref().getType()); + if (!eleTy) + return emitOpError("must be a reference type"); + auto arrTy = eleTy.dyn_cast(); + if (!arrTy) + return emitOpError("must be a reference to an array"); + auto arrDim = arrTy.getDimension(); + + if (auto shapeOp = shape()) { + auto shapeTy = shapeOp.getType(); + unsigned shapeTyRank = 0; + if (auto s = shapeTy.dyn_cast()) { + shapeTyRank = s.getRank(); + } else { + auto ss = shapeTy.cast(); + shapeTyRank = ss.getRank(); + } + if (arrDim && arrDim != shapeTyRank) + return emitOpError("rank of dimension mismatched"); + } + + if (auto sliceOp = slice()) { + if (auto sliceTy = sliceOp.getType().dyn_cast()) { + if (sliceTy.getRank() != arrDim) + return emitOpError("rank of dimension in slice mismatched"); + } + } + + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + std::vector getExtents(); + }]; +} + +def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> { + + let summary = "Fetch the value of an element of an array value"; + + let description = [{ + Fetch the value of an element in an array value. + + ```fortran + real :: a(n,m) + ... + ... a ... + ... a(r,s+1) ... + ``` + + One can use `fir.array_fetch` to fetch the (implied) value of `a(i,j)` in + an array expression as shown above. It can also be used to extract the + element `a(r,s+1)` in the second expression. + + ```mlir + %s = fir.shape %n, %m : (index, index) -> !fir.shape<2> + // load the entire array 'a' + %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array + // fetch the value of one of the array value's elements + %1 = fir.array_fetch %v, %i, %j : (!fir.array, index, index) -> f32 + ``` + + It is only possible to use `array_fetch` on an `array_load` result value. 
+ }]; + + let arguments = (ins + fir_SequenceType:$sequence, + Variadic:$indices + ); + + let results = (outs AnyType:$element); + + let assemblyFormat = [{ + $sequence `,` $indices attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto arrTy = sequence().getType().cast(); + if (indices().size() != arrTy.getDimension()) + return emitOpError("number of indices != dimension of array"); + if (element().getType() != arrTy.getEleTy()) + return emitOpError("return type does not match array"); + if (!isa(sequence().getDefiningOp())) + return emitOpError("argument #0 must be result of fir.array_load"); + return mlir::success(); + }]; +} + +def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> { + + let summary = "Update the value of an element of an array value"; + + let description = [{ + Updates the value of an element in an array value. A new array value is + returned where all element values of the input array are identical except + for the selected element which is the value passed in the update. + + ```fortran + real :: a(n,m) + ... + a = ... + ``` + + One can use `fir.array_update` to update the (implied) value of `a(i,j)` + in an array expression as shown above. + + ```mlir + %s = fir.shape %n, %m : (index, index) -> !fir.shape<2> + // load the entire array 'a' + %v = fir.array_load %a(%s) : (!fir.ref>, !fir.shape<2>) -> !fir.array + // update the value of one of the array value's elements + // %r_{ij} = %f if (i,j) = (%i,%j), %v_{ij} otherwise + %r = fir.array_update %v, %f, %i, %j : (!fir.array, f32, index, index) -> !fir.array + fir.array_merge_store %v, %r to %a : !fir.ref> + ``` + + An array value update behaves as if a mapping function from the indices + to the new value has been added, replacing the previous mapping. These + mappings can be added to the ssa-value, but will not be materialized in + memory until the `fir.array_merge_store` is performed. 
+ }]; + + let arguments = (ins + fir_SequenceType:$sequence, + AnyType:$merge, + Variadic:$indices + ); + + let results = (outs fir_SequenceType); + + let assemblyFormat = [{ + $sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto arrTy = sequence().getType().cast(); + if (merge().getType() != arrTy.getEleTy()) + return emitOpError("merged value does not have element type"); + if (indices().size() != arrTy.getDimension()) + return emitOpError("number of indices != dimension of array"); + return mlir::success(); + }]; +} + +def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [ + TypesMatchWith<"type of 'original' matches element type of 'memref'", + "memref", "original", + "fir::dyn_cast_ptrEleTy($_self)">, + TypesMatchWith<"type of 'sequence' matches element type of 'memref'", + "memref", "sequence", + "fir::dyn_cast_ptrEleTy($_self)">]> { + + let summary = "Store merged array value to memory."; + + let description = [{ + Store a merged array value to memory. + + ```fortran + real :: a(n,m) + ... + a = ... + ``` + + One can use `fir.array_merge_store` to merge/copy the value of `a` in an + array expression as shown above. + + ```mlir + %v = fir.array_load %a(%shape) : ... + %r = fir.array_update %v, %f, %i, %j : (!fir.array, f32, index, index) -> !fir.array + fir.array_merge_store %v, %r to %a : !fir.ref> + ``` + + This operation merges the original loaded array value, `%v`, with the + chained updates, `%r`, and stores the result to the array at address, `%a`. 
+ }]; + + let arguments = (ins + fir_SequenceType:$original, + fir_SequenceType:$sequence, + Arg:$memref + ); + + let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)"; + + let verifier = [{ + if (!isa(original().getDefiningOp())) + return emitOpError("operand #0 must be result of a fir.array_load op"); + return mlir::success(); + }]; +} + +//===----------------------------------------------------------------------===// // Record and array type operations +//===----------------------------------------------------------------------===// + +def fir_ArrayCoorOp : fir_Op<"array_coor", + [NoSideEffect, AttrSizedOperandSegments]> { + + let summary = "Find the coordinate of an element of an array"; + + let description = [{ + Compute the location of an element in an array when the shape of the + array is only known at runtime. + + This operation is intended to capture all the runtime values needed to + compute the address of an array reference in a single high-level op. Given + the following Fortran input: + + ```fortran + real :: a(n,m) + ... + ... a(i,j) ... + ``` + + One can use `fir.array_coor` to determine the address of `a(i,j)`. + + ```mlir + %s = fir.shape %n, %m : (index, index) -> !fir.shape<2> + %1 = fir.array_coor %a(%s) %i, %j : (!fir.ref>, !fir.shape<2>, index, index) -> !fir.ref + ``` + }]; + + let arguments = (ins + AnyReferenceLike:$memref, + Optional:$shape, + Optional:$slice, + Variadic:$indices, + Variadic:$lenParams + ); + + let results = (outs fir_ReferenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)? 
attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto eleTy = fir::dyn_cast_ptrEleTy(memref().getType()); + if (!eleTy) + return emitOpError("must be a reference type"); + auto arrTy = eleTy.dyn_cast(); + if (!arrTy) + return emitOpError("must be a reference to an array"); + auto arrDim = arrTy.getDimension(); + + if (auto shapeOp = shape()) { + auto shapeTy = shapeOp.getType(); + unsigned shapeTyRank = 0; + if (auto s = shapeTy.dyn_cast()) { + shapeTyRank = s.getRank(); + } else { + auto ss = shapeTy.cast(); + shapeTyRank = ss.getRank(); + } + if (arrDim && arrDim != shapeTyRank) + return emitOpError("rank of dimension mismatched"); + if (shapeTyRank != indices().size()) + return emitOpError("number of indices do not match dim rank"); + } + + if (auto sliceOp = slice()) { + if (auto sliceTy = sliceOp.getType().dyn_cast()) { + if (sliceTy.getRank() != arrDim) + return emitOpError("rank of dimension in slice mismatched"); + } + } + + return mlir::success(); + }]; +} def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { let summary = "Finds the coordinate (location) of a value in memory"; @@ -1567,7 +1929,7 @@ p.printFunctionalType((*this)->getOperandTypes(), (*this)->getResultTypes()); }]; - + let verifier = [{ auto refTy = ref().getType(); if (fir::isa_ref_type(refTy)) { @@ -1578,13 +1940,16 @@ if (arrTy.getConstantRows() < arrTy.getDimension() - 1) return emitOpError("cannot find coordinate with unknown extents"); } + if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || + fir::isa_char_string(eleTy))) + return emitOpError("cannot apply coordinate_of to this type"); } // Recovering a LEN type parameter only makes sense from a boxed value for (auto co : coor()) if (dyn_cast_or_null(co.getDefiningOp())) { if (getNumOperands() != 2) return emitOpError("len_param_index must be last argument"); - if (!ref().getType().dyn_cast()) + if (!ref().getType().isa()) return emitOpError("len_param_index must be used on box type"); 
} if (auto attr = (*this)->getAttr(CoordinateOp::baseType())) { @@ -1598,11 +1963,12 @@ let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "Type":$type, "Value":$ref, "ValueRange":$coor, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attrs)>]; - + OpBuilderDAG<(ins "mlir::Type":$type, "mlir::Value":$ref, + "mlir::ValueRange":$coor, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "mlir::Type":$type, "mlir::ValueRange":$operands, + CArg<"llvm::ArrayRef", "{}">:$attrs)>]; + let extraClassDeclaration = [{ static constexpr llvm::StringRef baseType() { return "base_type"; } mlir::Type getBaseType(); @@ -1616,6 +1982,8 @@ Extract a value from an entity with a type composed of tuples, arrays, and/or derived types. Returns the value from entity with the type of the specified component. Cannot be used on values of `!fir.box` type. + It can also be used to access complex parts and elements of a character + string. Note that the entity ssa-value must be of compile-time known size in order to use this operation. 
@@ -1707,53 +2075,172 @@ } }]; - let builders = [ - OpBuilderDAG<(ins "StringRef":$fieldName, "Type":$recTy, - CArg<"ValueRange", "{}">:$operands), + let builders = [OpBuilderDAG<(ins "llvm::StringRef":$fieldName, + "mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands), [{ - $_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName)); + $_state.addAttribute(fieldAttrName(), + $_builder.getStringAttr(fieldName)); $_state.addAttribute(typeAttrName(), TypeAttr::get(recTy)); $_state.addOperands(operands); - }]>]; + }] + >]; let extraClassDeclaration = [{ static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } static constexpr llvm::StringRef typeAttrName() { return "on_type"; } + llvm::StringRef getFieldName() { return field_id(); } }]; } -def fir_GenDimsOp : fir_OneResultOp<"gendims", [NoSideEffect]> { +def fir_ShapeOp : fir_Op<"shape", [NoSideEffect]> { - let summary = "generate a value of type `!fir.dims`"; + let summary = "generate an abstract shape vector of type `!fir.shape`"; let description = [{ - The arguments are an ordered list of integral type values that is a - multiple of 3 in length. Each such triple is defined as: the lower - index, the extent, and the stride for that dimension. The dimension - information is given in the same row-to-column order as Fortran. This - abstract dimension value must describe a reified object, so all dimension - information must be specified. The extent must be nonnegative and the - stride must not be zero. + The arguments are an ordered list of integral type values that define the + runtime extent of each dimension of an array. The shape information is + given in the same row-to-column order as Fortran. This abstract shape value + must be applied to a reified object, so all shape information must be + specified. The extent must be nonnegative. 
```mlir - %d = fir.gendims %lo, %ext, %str : (index, index, index) -> !fir.dims<1> + %d = fir.shape %row_sz, %col_sz : (index, index) -> !fir.shape<2> ``` }]; - let arguments = (ins Variadic:$triples); + let arguments = (ins Variadic:$extents); - let results = (outs fir_DimsType); + let results = (outs fir_ShapeType); let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; let verifier = [{ - auto size = triples().size(); - if (size < 1 || size > 16 * 3) + auto size = extents().size(); + auto shapeTy = getType().dyn_cast(); + assert(shapeTy && "must be a shape type"); + if (shapeTy.getRank() != size) + return emitOpError("shape type rank mismatch"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + std::vector getExtents() { + return {extents().begin(), extents().end()}; + } + }]; +} + +def fir_ShapeShiftOp : fir_Op<"shape_shift", [NoSideEffect]> { + + let summary = [{ + generate an abstract shape and shift vector of type `!fir.shapeshift` + }]; + + let description = [{ + The arguments are an ordered list of integral type values that is a multiple + of 2 in length. Each such pair is defined as: the lower bound and the + extent for that dimension. The shifted shape information is given in the + same row-to-column order as Fortran. This abstract shifted shape value must + be applied to a reified object, so all shifted shape information must be + specified. The extent must be nonnegative. 
+ + ```mlir + %d = fir.shape_shift %lo, %extent : (index, index) -> !fir.shapeshift<1> + ``` + }]; + + let arguments = (ins Variadic:$pairs); + + let results = (outs fir_ShapeShiftType); + + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto size = pairs().size(); + if (size < 2 || size > 16 * 2) return emitOpError("incorrect number of args"); + if (size % 2 != 0) + return emitOpError("requires a multiple of 2 args"); + auto shapeTy = getType().dyn_cast(); + assert(shapeTy && "must be a shape shift type"); + if (shapeTy.getRank() * 2 != size) + return emitOpError("shape type rank mismatch"); + return mlir::success(); + }]; + + let extraClassDeclaration = [{ + // Logically unzip the origins from the extent values. + std::vector getOrigins() { + std::vector result; + for (auto i : llvm::enumerate(pairs())) + if (!(i.index() & 1)) + result.push_back(i.value()); + return result; + } + + // Logically unzip the extents from the origin values. + std::vector getExtents() { + std::vector result; + for (auto i : llvm::enumerate(pairs())) + if (i.index() & 1) + result.push_back(i.value()); + return result; + } + }]; +} + +def fir_SliceOp : fir_Op<"slice", [NoSideEffect, AttrSizedOperandSegments]> { + + let summary = "generate an abstract slice vector of type `!fir.slice`"; + + let description = [{ + The array slicing arguments are an ordered list of integral type values + that must be a multiple of 3 in length. Each such triple is defined as: + the lower bound, the upper bound, and the stride for that dimension, as in + Fortran syntax. Both bounds are inclusive. The array slice information is + given in the same row-to-column order as Fortran. This abstract slice value + must be applied to a reified object, so all slice information must be + specified. The extent must be nonnegative and the stride must not be zero. 
+ + ```mlir + %d = fir.slice %lo, %hi, %step : (index, index, index) -> !fir.slice<1> + ``` + + To support generalized slicing of Fortran's dynamic derived types, a slice + op can be given a component path (narrowing from the product type of the + original array to the specific elemental type of the sliced projection). + + ```mlir + %fld = fir.field_index component, !fir.type + %d = fir.slice %lo, %hi, %step path %fld : (index, index, index, !fir.field) -> !fir.slice<1> + ``` + }]; + + let arguments = (ins + Variadic:$triples, + Variadic:$fields + ); + + let results = (outs fir_SliceType); + + let assemblyFormat = [{ + $triples (`path` $fields^)? attr-dict `:` functional-type(operands, results) + }]; + + let verifier = [{ + auto size = triples().size(); + if (size < 3 || size > 16 * 3) + return emitOpError("incorrect number of args for triple"); if (size % 3 != 0) return emitOpError("requires a multiple of 3 args"); + auto sliceTy = getType().dyn_cast(); + assert(sliceTy && "must be a slice type"); + if (sliceTy.getRank() * 3 != size) + return emitOpError("slice type rank mismatch"); return mlir::success(); }]; } @@ -1762,9 +2249,11 @@ let summary = "insert a new sub-value into a copy of an existing aggregate"; let description = [{ - Insert a value from an entity with a type composed of tuples, arrays, + Insert a value into an entity with a type composed of tuples, arrays, and/or derived types. Returns a new ssa value with the same type as the original entity. Cannot be used on values of `!fir.box` type. + It can also be used to set complex parts and elements of a character + string. Note that the entity ssa-value must be of compile-time known size in order to use this operation. 
@@ -1785,6 +2274,26 @@ let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; + + let hasCanonicalizer = 1; +} + +def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoSideEffect]> { + let summary = "insert sub-value into a range on an existing sequence"; + + let description = [{ + Insert a constant value into an entity with an array type. Returns a + new ssa value where the range of offsets from the original array have been + replaced with the constant. The result is an array type entity. + }]; + + let arguments = (ins fir_SequenceType:$seq, AnyType:$val, + Variadic:$coor); + let results = (outs fir_SequenceType); + + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; } def fir_LenParamIndexOp : fir_OneResultOp<"len_param_index", [NoSideEffect]> { @@ -1830,12 +2339,13 @@ << ", " << (*this)->getAttr(typeAttrName()); }]; - let builders = [ - OpBuilderDAG<(ins "StringRef":$fieldName, "Type":$recTy), + let builders = [OpBuilderDAG<(ins "llvm::StringRef":$fieldName, + "mlir::Type":$recTy), [{ $_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName)); $_state.addAttribute(typeAttrName(), TypeAttr::get(recTy)); - }]>]; + }] + >]; let extraClassDeclaration = [{ static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } @@ -1852,7 +2362,7 @@ def fir_ResultOp : fir_Op<"result", [NoSideEffect, ReturnLike, Terminator, - ParentOneOf<["WhereOp", "LoopOp", "IterWhileOp"]>]> { + ParentOneOf<["IfOp", "DoLoopOp", "IterWhileOp"]>]> { let summary = "special terminator for use in fir region operations"; let description = [{ @@ -1863,10 +2373,7 @@ }]; let arguments = (ins Variadic:$results); - let builders = [ - OpBuilderDAG<(ins), - [{/* do nothing */}]> - ]; + let builders = [OpBuilderDAG<(ins), [{ /* do nothing */ }]>]; let assemblyFormat = "($results^ `:` type($results))? 
attr-dict"; @@ -1883,7 +2390,7 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -def fir_LoopOp : region_Op<"do_loop", +def fir_DoLoopOp : region_Op<"do_loop", [DeclareOpInterfaceMethods]> { let summary = "generalized loop operation"; let description = [{ @@ -1897,7 +2404,7 @@ fir.do_loop %i = %l to %u step %s unordered { %x = fir.convert %i : (index) -> i32 %v = fir.call @compute(%x) : (i32) -> f32 - %p = fir.coordinate_of %A, %i : (!fir.ref, index) -> !fir.ref + %p = fir.coordinate_of %A, %i : (!fir.ref>, index) -> !fir.ref fir.store %v to %p : !fir.ref } ``` @@ -1911,23 +2418,26 @@ Index:$upperBound, Index:$step, Variadic:$initArgs, - OptionalAttr:$unordered - ); - let results = (outs - Variadic:$results + OptionalAttr:$unordered, + OptionalAttr:$finalValue ); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins "mlir::Value":$lowerBound, "mlir::Value":$upperBound, "mlir::Value":$step, CArg<"bool", "false">:$unordered, - CArg<"ValueRange", "llvm::None">:$iterArgs, - CArg<"ArrayRef", "{}">:$attributes)> + CArg<"bool", "false">:$finalCountValue, + CArg<"mlir::ValueRange", "llvm::None">:$iterArgs, + CArg<"llvm::ArrayRef", "{}">:$attributes)> ]; let extraClassDeclaration = [{ static constexpr llvm::StringRef unorderedAttrName() { return "unordered"; } + static constexpr llvm::StringRef finalValueAttrName() { + return "finalValue"; + } mlir::Value getInductionVar() { return getBody()->getArgument(0); } mlir::OpBuilder getBodyBuilder() { @@ -1966,10 +2476,15 @@ (*this)->setAttr(unorderedAttrName(), mlir::UnitAttr::get(getContext())); } + + mlir::BlockArgument iterArgToBlockArg(mlir::Value iterArg); + void resultToSourceOps(llvm::SmallVectorImpl &results, + unsigned resultNum); + mlir::Value blockArgToSourceOp(unsigned blockArgNum); }]; } -def fir_WhereOp : region_Op<"if", [NoRegionArguments]> { +def fir_IfOp : region_Op<"if", [NoRegionArguments]> { let 
summary = "if-then-else conditional operation"; let description = [{ Used to conditionally execute operations. This operation is the FIR @@ -1990,28 +2505,31 @@ let results = (outs Variadic:$results); let regions = (region - SizedRegion<1>:$whereRegion, - AnyRegion:$otherRegion + SizedRegion<1>:$thenRegion, + AnyRegion:$elseRegion ); let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "Value":$cond, "bool":$withOtherRegion)>, - OpBuilderDAG<(ins "TypeRange":$resultTypes, "Value":$cond, - "bool":$withOtherRegion)> + OpBuilderDAG<(ins "mlir::Value":$cond, "bool":$withElseRegion)>, + OpBuilderDAG<(ins "mlir::TypeRange":$resultTypes, "mlir::Value":$cond, + "bool":$withElseRegion)> ]; let extraClassDeclaration = [{ - mlir::OpBuilder getWhereBodyBuilder() { - assert(!whereRegion().empty() && "Unexpected empty 'where' region."); - mlir::Block &body = whereRegion().front(); + mlir::OpBuilder getThenBodyBuilder() { + assert(!thenRegion().empty() && "Unexpected empty 'where' region."); + mlir::Block &body = thenRegion().front(); return mlir::OpBuilder(&body, std::prev(body.end())); } - mlir::OpBuilder getOtherBodyBuilder() { - assert(!otherRegion().empty() && "Unexpected empty 'other' region."); - mlir::Block &body = otherRegion().front(); + mlir::OpBuilder getElseBodyBuilder() { + assert(!elseRegion().empty() && "Unexpected empty 'other' region."); + mlir::Block &body = elseRegion().front(); return mlir::OpBuilder(&body, std::prev(body.end())); } + + void resultToSourceOps(llvm::SmallVectorImpl &results, + unsigned resultNum); }]; } @@ -2019,12 +2537,30 @@ [DeclareOpInterfaceMethods]> { let summary = "DO loop with early exit condition"; let description = [{ - This construct is useful for lowering implied-DO loops. It is very similar - to `fir::LoopOp` with the addition that it requires a single loop-carried - bool value that signals an early exit condition to the operation. A `true` - disposition means the next loop iteration should proceed. 
A `false` - indicates that the `fir.iterate_while` operation should terminate and - return its iteration arguments. + This single-entry, single-exit looping construct is useful for lowering + counted loops that can exit early such as, for instance, implied-DO loops. + It is very similar to `fir::DoLoopOp` with the addition that it requires + a single loop-carried bool value that signals an early exit condition to + the operation. A `true` disposition means the next loop iteration should + proceed. A `false` indicates that the `fir.iterate_while` operation should + terminate and return its iteration arguments. This is a degenerate counted + loop in that the loop is not guaranteed to execute all iterations. + + An example iterate_while that returns the counter value, the early + termination condition, and an extra loop-carried value is shown here. This + loop counts from %lo to %up (inclusive), stepping by %c1, so long as the + early exit (%ok) is true. The iter_args %sh value is also carried by the + loop. The result triple is the values of %i=phi(%lo,%i+%c1), + %ok=phi(%okIn,%okNew), and %sh=phi(%shIn,%shNew) from the last executed + iteration. 
+ + ```mlir + %v:3 = fir.iterate_while (%i = %lo to %up step %c1) and (%ok = %okIn) iter_args(%sh = %shIn) -> (index, i1, i16) { + %shNew = fir.call @bar(%sh) : (i16) -> i16 + %okNew = fir.call @foo(%sh) : (i16) -> i1 + fir.result %i, %okNew, %shNew : index, i1, i16 + } + ``` }]; let arguments = (ins @@ -2032,23 +2568,25 @@ Index:$upperBound, Index:$step, I1:$iterateIn, - Variadic:$initArgs - ); - let results = (outs - I1:$iterateResult, - Variadic:$results + Variadic:$initArgs, + OptionalAttr:$finalValue ); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins "mlir::Value":$lowerBound, "mlir::Value":$upperBound, "mlir::Value":$step, "mlir::Value":$iterate, - CArg<"ValueRange", "llvm::None">:$iterArgs, - CArg<"ArrayRef", "{}">:$attributes)> + CArg<"bool", "false">:$finalCountValue, + CArg<"mlir::ValueRange", "llvm::None">:$iterArgs, + CArg<"llvm::ArrayRef", "{}">:$attributes)> ]; - + let extraClassDeclaration = [{ + static constexpr llvm::StringRef finalValueAttrName() { + return "finalValue"; + } mlir::Block *getBody() { return ®ion().front(); } mlir::Value getIterateVar() { return getBody()->getArgument(1); } mlir::Value getInductionVar() { return getBody()->getArgument(0); } @@ -2080,6 +2618,11 @@ unsigned getNumIterOperands() { return (*this)->getNumOperands() - getNumControlOperands(); } + + mlir::BlockArgument iterArgToBlockArg(mlir::Value iterArg); + void resultToSourceOps(llvm::SmallVectorImpl &results, + unsigned resultNum); + mlir::Value blockArgToSourceOp(unsigned blockArgNum); }]; } @@ -2087,8 +2630,7 @@ // Procedure call operations //===----------------------------------------------------------------------===// -def fir_CallOp : fir_Op<"call", - [MemoryEffects<[MemAlloc, MemFree, MemRead, MemWrite]>]> { +def fir_CallOp : fir_Op<"call", [CallOpInterface]> { let summary = "call a procedure"; let description = [{ @@ -2107,19 +2649,63 @@ 
OptionalAttr:$callee, Variadic:$args ); - let results = (outs Variadic); let parser = "return parseCallOp(parser, result);"; let printer = "printCallOp(p, *this);"; + let builders = [ + OpBuilderDAG<(ins "mlir::FuncOp":$callee, + CArg<"mlir::ValueRange", "{}">:$operands), + [{ + $_state.addOperands(operands); + $_state.addAttribute(calleeAttrName(), + $_builder.getSymbolRefAttr(callee)); + $_state.addTypes(callee.getType().getResults()); + }]>, + OpBuilderDAG<(ins "mlir::SymbolRefAttr":$callee, + "llvm::ArrayRef":$results, + CArg<"mlir::ValueRange", "{}">:$operands), + [{ + $_state.addOperands(operands); + $_state.addAttribute(calleeAttrName(), callee); + $_state.addTypes(results); + }]>, + OpBuilderDAG<(ins "llvm::StringRef":$callee, + "llvm::ArrayRef":$results, + CArg<"mlir::ValueRange", "{}">:$operands), + [{ + build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results, + operands); + }]>]; + let extraClassDeclaration = [{ static constexpr StringRef calleeAttrName() { return "callee"; } + + mlir::FunctionType getFunctionType(); + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + if (auto calling = + (*this)->getAttrOfType(calleeAttrName())) + return {arg_operand_begin(), arg_operand_end()}; + return {arg_operand_begin() + 1, arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. 
+ CallInterfaceCallable getCallableForCallee() { + if (auto calling = + (*this)->getAttrOfType(calleeAttrName())) + return calling; + return getOperand(0); + } }]; } -def fir_DispatchOp : fir_Op<"dispatch", - [MemoryEffects<[MemAlloc, MemFree, MemRead, MemWrite]>]> { +def fir_DispatchOp : fir_Op<"dispatch", []> { let summary = "call a type-bound procedure"; let description = [{ @@ -2147,10 +2733,11 @@ llvm::StringRef calleeName; if (failed(parser.parseOptionalKeyword(&calleeName))) { mlir::StringAttr calleeAttr; - if (parser.parseAttribute(calleeAttr, "method", result.attributes)) + if (parser.parseAttribute(calleeAttr, methodAttrName(), + result.attributes)) return mlir::failure(); } else { - result.addAttribute("method", + result.addAttribute(methodAttrName(), parser.getBuilder().getStringAttr(calleeName)); } if (parser.parseOperandList(operands, @@ -2161,22 +2748,19 @@ parser.resolveOperands( operands, calleeType.getInputs(), calleeLoc, result.operands)) return mlir::failure(); - result.addAttribute("fn_type", mlir::TypeAttr::get(calleeType)); return mlir::success(); }]; let printer = [{ - p << getOperationName() << ' ' << (*this)->getAttr("method") << '('; + p << getOperationName() << ' ' << (*this)->getAttr(methodAttrName()) << '('; p.printOperand(object()); - if (arg_operand_begin() != arg_operand_end()) { + if (!args().empty()) { p << ", "; p.printOperands(args()); } - p << ')'; - p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"}); - auto resTy{getResultTypes()}; - llvm::SmallVector argTy(getOperandTypes()); - p << " : " << mlir::FunctionType::get(getContext(), argTy, resTy); + p << ") : "; + p.printFunctionalType((*this)->getOperandTypes(), + (*this)->getResultTypes()); }]; let extraClassDeclaration = [{ @@ -2184,9 +2768,13 @@ operand_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } + // operand[0] is the object (of box type) operand_iterator arg_operand_begin() { return operand_begin() + 1; } operand_iterator 
arg_operand_end() { return operand_end(); } - llvm::StringRef passArgAttrName() { return "pass_arg_pos"; } + static constexpr llvm::StringRef passArgAttrName() { + return "pass_arg_pos"; + } + static constexpr llvm::StringRef methodAttrName() { return "method"; } unsigned passArgPos(); }]; } @@ -2207,7 +2795,7 @@ ``` }]; - let results = (outs fir_SequenceType); + let results = (outs fir_CharacterType); let parser = [{ auto &builder = parser.getBuilder(); @@ -2229,10 +2817,11 @@ parser.parseRParen() || parser.parseColonType(type)) return mlir::failure(); - if (!(type.isa() || type.isa())) + auto charTy = type.dyn_cast(); + if (!charTy) return parser.emitError(parser.getCurrentLocation(), "must have character type"); - type = fir::SequenceType::get({sz.getInt()}, type); + type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(), sz.getInt()); if (!type || parser.addTypesToList(type, result.types)) return mlir::failure(); return mlir::success(); @@ -2241,15 +2830,12 @@ let printer = [{ p << getOperationName() << ' ' << getValue() << '('; p << getSize().cast().getValue() << ") : "; - p.printType(getType().cast().getEleTy()); + p.printType(getType()); }]; let verifier = [{ if (getSize().cast().getValue().isNegative()) return emitOpError("size must be non-negative"); - auto eleTy = getType().cast().getEleTy(); - if (!eleTy.isa()) - return emitOpError("must have !fir.char type"); if (auto xl = (*this)->getAttr(xlist())) { auto xList = xl.cast(); for (auto a : xList) @@ -2316,11 +2902,11 @@ }]; let arguments = (ins FirRealAttr:$constant); - + let results = (outs fir_RealType:$res); let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)"; - + let verifier = [{ if (!getType().isa()) return emitOpError("must be a !fir.real type"); @@ -2364,9 +2950,8 @@ let results = (outs AnyLogicalLike); - let builders = [ - OpBuilderDAG<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), - [{ + let builders = [OpBuilderDAG<(ins 
"mlir::CmpFPredicate":$predicate, + "mlir::Value":$lhs, "mlir::Value":$rhs), [{ buildCmpFOp($_builder, $_state, predicate, lhs, rhs); }]>]; @@ -2397,7 +2982,7 @@ }]; let results = (outs fir_ComplexType); - + let parser = [{ fir::RealAttr realp; fir::RealAttr imagp; @@ -2427,7 +3012,7 @@ }]; let verifier = [{ - if (!getType().isa()) + if (!getType().isa()) return emitOpError("must be a !fir.complex type"); return mlir::success(); }]; @@ -2473,9 +3058,8 @@ let printer = "printCmpcOp(p, *this);"; - let builders = [ - OpBuilderDAG<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), - [{ + let builders = [OpBuilderDAG<(ins "mlir::CmpFPredicate":$predicate, + "mlir::Value":$lhs, "mlir::Value":$rhs), [{ buildCmpCOp($_builder, $_state, predicate, lhs, rhs); }]>]; @@ -2495,10 +3079,11 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> { let summary = "convert a symbol to an SSA value"; - + let description = [{ Convert a symbol (a function or global reference) to an SSA-value to be - used in other Operations. + used in other Operations. References to Fortran symbols are distinguished + via this operation from other arbitrary constant values. ```mlir %p = fir.address_of(@symbol) : !fir.ref @@ -2507,14 +3092,14 @@ let arguments = (ins SymbolRefAttr:$symbol); - let results = (outs fir_ReferenceType:$resTy); + let results = (outs AnyAddressableLike:$resTy); let assemblyFormat = "`(` $symbol `)` attr-dict `:` type($resTy)"; } def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> { let summary = "encapsulates all Fortran scalar type conversions"; - + let description = [{ Generalized type conversion. Convert the ssa value from type T to type U. Not all pairs of types have conversions. 
When types T and U are the same @@ -2559,12 +3144,13 @@ static bool isFloatCompatible(mlir::Type ty); static bool isPointerCompatible(mlir::Type ty); }]; + let hasCanonicalizer = 1; } def FortranTypeAttr : Attr()">, Or<[CPred<"$_self.cast().getValue().isa()">, - CPred<"$_self.cast().getValue().isa()">, - CPred<"$_self.cast().getValue().isa()">, + CPred<"$_self.cast().getValue().isa()">, + CPred<"$_self.cast().getValue().isa()">, CPred<"$_self.cast().getValue().isa()">, CPred<"$_self.cast().getValue().isa()">, CPred<"$_self.cast().getValue().isa()">]>]>, @@ -2605,9 +3191,7 @@ p.printOptionalAttrDict(getAttrs(), {"in_type"}); }]; - let builders = [ - OpBuilderDAG<(ins "mlir::TypeAttr":$inty)> - ]; + let builders = [OpBuilderDAG<(ins "mlir::TypeAttr":$inty)>]; let verifier = [{ mlir::Type resultTy = getType(); @@ -2711,22 +3295,24 @@ let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "StringRef":$name, "Type":$type, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilderDAG<(ins "StringRef":$name, "bool":$isConstant, "Type":$type, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilderDAG<(ins "StringRef":$name, "Type":$type, - CArg<"StringAttr", "{}">:$linkage, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilderDAG<(ins "StringRef":$name, "bool":$isConstant, "Type":$type, - CArg<"StringAttr", "{}">:$linkage, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilderDAG<(ins "StringRef":$name, "Type":$type, "Attribute":$initVal, - CArg<"StringAttr", "{}">:$linkage, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilderDAG<(ins "StringRef":$name, "bool":$isConstant, "Type":$type, - "Attribute":$initVal, CArg<"StringAttr", "{}">:$linkage, - CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "llvm::StringRef":$name, "mlir::Type":$type, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "llvm::StringRef":$name, "bool":$isConstant, + "mlir::Type":$type, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "llvm::StringRef":$name, "mlir::Type":$type, + CArg<"mlir::StringAttr", 
"{}">:$linkage, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "llvm::StringRef":$name, "bool":$isConstant, + "mlir::Type":$type, CArg<"mlir::StringAttr", "{}">:$linkage, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "llvm::StringRef":$name, "mlir::Type":$type, + "mlir::Attribute":$initVal, CArg<"mlir::StringAttr", "{}">:$linkage, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, + OpBuilderDAG<(ins "llvm::StringRef":$name, "bool":$isConstant, + "mlir::Type":$type, "mlir::Attribute":$initVal, + CArg<"mlir::StringAttr", "{}">:$linkage, + CArg<"llvm::ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = [{ @@ -2745,7 +3331,7 @@ mlir::Type resultType() { return fir::AllocaOp::wrapResultType(getType()); } - + /// Return the initializer attribute if it exists, or a null attribute. Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); } @@ -2887,8 +3473,8 @@ let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "StringRef":$name, "Type":$type, - CArg<"ArrayRef", "{}">:$attrs), + OpBuilderDAG<(ins "llvm::StringRef":$name, "mlir::Type":$type, + CArg<"llvm::ArrayRef", "{}">:$attrs), [{ $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), $_builder.getStringAttr(name)); Index: flang/include/flang/Optimizer/Dialect/FIRType.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIRType.h +++ flang/include/flang/Optimizer/Dialect/FIRType.h @@ -5,6 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #ifndef OPTIMIZER_DIALECT_FIRTYPE_H #define OPTIMIZER_DIALECT_FIRTYPE_H @@ -32,18 +36,17 @@ class FIROpsDialect; -using KindTy = int; +using KindTy = unsigned; namespace detail { struct BoxTypeStorage; struct 
BoxCharTypeStorage; struct BoxProcTypeStorage; struct CharacterTypeStorage; -struct CplxTypeStorage; -struct DimsTypeStorage; +struct ComplexTypeStorage; struct FieldTypeStorage; struct HeapTypeStorage; -struct IntTypeStorage; +struct IntegerTypeStorage; struct LenTypeStorage; struct LogicalTypeStorage; struct PointerTypeStorage; @@ -51,7 +54,11 @@ struct RecordTypeStorage; struct ReferenceTypeStorage; struct SequenceTypeStorage; +struct ShapeTypeStorage; +struct ShapeShiftTypeStorage; +struct SliceTypeStorage; struct TypeDescTypeStorage; +struct VectorTypeStorage; } // namespace detail // These isa_ routines follow the precedent of llvm::isa_or_null<> @@ -68,7 +75,9 @@ /// Is `t` a FIR dialect type that implies a memory (de)reference? bool isa_ref_type(mlir::Type t); -/// Is `t` a type that is always trivially pass-by-reference? +/// Is `t` a type that is always trivially pass-by-reference? Specifically, this +/// is testing if `t` is a ReferenceType or any box type. Compare this to +/// conformsWithPassByRef(), which includes pointers and allocatables. bool isa_passbyref_type(mlir::Type t); /// Is `t` a boxed type? @@ -91,25 +100,44 @@ // Intrinsic types /// Model of the Fortran CHARACTER intrinsic type, including the KIND type -/// parameter. The model does not include a LEN type parameter. A CharacterType -/// is thus the type of a single character value. +/// parameter. The model optionally includes a LEN type parameter. A +/// CharacterType is thus the type of both a single character value and a +/// character with a LEN parameter. class CharacterType : public mlir::Type::TypeBase { public: using Base::Base; - static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind); + using LenType = std::int64_t; + + static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind, LenType len); + /// Return unknown length CHARACTER type. 
+ static CharacterType getUnknownLen(mlir::MLIRContext *ctxt, KindTy kind) { + return get(ctxt, kind, unknownLen()); + } + /// Return length 1 CHARACTER type. + static CharacterType getSingleton(mlir::MLIRContext *ctxt, KindTy kind) { + return get(ctxt, kind, singleton()); + } KindTy getFKind() const; + + /// CHARACTER is a singleton and has a LEN of 1. + static constexpr LenType singleton() { return 1; } + /// CHARACTER has an unknown LEN property. + static constexpr LenType unknownLen() { return -1; } + + /// Access to a CHARACTER's LEN property. Defaults to 1. + LenType getLen() const; }; /// Model of a Fortran COMPLEX intrinsic type, including the KIND type /// parameter. COMPLEX is a floating point type with a real and imaginary /// member. -class CplxType : public mlir::Type::TypeBase { +class ComplexType : public mlir::Type::TypeBase { public: using Base::Base; - static CplxType get(mlir::MLIRContext *ctxt, KindTy kind); + static fir::ComplexType get(mlir::MLIRContext *ctxt, KindTy kind); /// Get the corresponding fir.real type. mlir::Type getElementType() const; @@ -119,19 +147,18 @@ /// Model of a Fortran INTEGER intrinsic type, including the KIND type /// parameter. -class IntType - : public mlir::Type::TypeBase { +class IntegerType : public mlir::Type::TypeBase { public: using Base::Base; - static IntType get(mlir::MLIRContext *ctxt, KindTy kind); + static fir::IntegerType get(mlir::MLIRContext *ctxt, KindTy kind); KindTy getFKind() const; }; /// Model of a Fortran LOGICAL intrinsic type, including the KIND type /// parameter. -class LogicalType - : public mlir::Type::TypeBase { +class LogicalType : public mlir::Type::TypeBase { public: using Base::Base; static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind); @@ -192,17 +219,38 @@ mlir::Type eleTy); }; -/// The type of a runtime vector that describes triples of array dimension -/// information. A triple consists of a lower bound, upper bound, and -/// stride. 
Each dimension of an array entity may have an associated triple that -/// maps how elements of the array are accessed. -class DimsType : public mlir::Type::TypeBase { +/// Type of a vector of runtime values that define the shape of a +/// multidimensional array object. The vector is the extents of each array +/// dimension. The rank of a ShapeType must be at least 1. +class ShapeType : public mlir::Type::TypeBase { public: using Base::Base; - static DimsType get(mlir::MLIRContext *ctx, unsigned rank); + static ShapeType get(mlir::MLIRContext *ctx, unsigned rank); + unsigned getRank() const; +}; - /// returns -1 if the rank is unknown +/// Type of a vector of runtime values that define the shape and the origin of a +/// multidimensional array object. The vector is of pairs, origin offset and +/// extent, of each array dimension. The rank of a ShapeShiftType must be at +/// least 1. +class ShapeShiftType + : public mlir::Type::TypeBase { +public: + using Base::Base; + static ShapeShiftType get(mlir::MLIRContext *ctx, unsigned rank); + unsigned getRank() const; +}; + +/// Type of a vector that represents an array slice operation on an array. +/// Fortran slices are triples of lower bound, upper bound, and stride. The rank +/// of a SliceType must be at least 1. +class SliceType : public mlir::Type::TypeBase { +public: + using Base::Base; + static SliceType get(mlir::MLIRContext *ctx, unsigned rank); unsigned getRank() const; }; @@ -373,14 +421,6 @@ llvm::StringRef name); }; -mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser); - -void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p); - -/// Guarantee `type` is a scalar integral type (standard Integer, standard -/// Index, or FIR Int). Aborts execution if condition is false. -void verifyIntegralType(mlir::Type type); - /// Is `t` a FIR Real or MLIR Float type? inline bool isa_real(mlir::Type t) { return t.isa() || t.isa(); @@ -389,14 +429,51 @@ /// Is `t` an integral type? 
inline bool isa_integer(mlir::Type t) { return t.isa() || t.isa() || - t.isa(); + t.isa(); } +/// Replacement for the standard dialect's vector type. Relaxes some of the +/// constraints and imposes some new ones. +class VectorType : public mlir::Type::TypeBase { +public: + using Base::Base; + + static fir::VectorType get(uint64_t len, mlir::Type eleTy); + mlir::Type getEleTy() const; + uint64_t getLen() const; + + static mlir::LogicalResult + verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy); + static bool isValidElementType(mlir::Type t) { + return isa_real(t) || isa_integer(t); + } +}; + +mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser); + +void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p); + +/// Guarantee `type` is a scalar integral type (standard Integer, standard +/// Index, or FIR Int). Aborts execution if condition is false. +void verifyIntegralType(mlir::Type type); + /// Is `t` a FIR or MLIR Complex type? inline bool isa_complex(mlir::Type t) { - return t.isa() || t.isa(); + return t.isa() || t.isa(); } +inline bool isa_char_string(mlir::Type t) { + if (auto ct = t.dyn_cast_or_null()) + return ct.getLen() != fir::CharacterType::singleton(); + return false; +} + +/// Is `t` a box type for which it is not possible to deduce the box size. +/// It is not possible to deduce the size of a box that describes an entity +/// of unknown rank or type. +bool isa_unknown_size_box(mlir::Type t); + } // namespace fir #endif // OPTIMIZER_DIALECT_FIRTYPE_H Index: flang/include/flang/Optimizer/TypePredicates.td =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/TypePredicates.td @@ -0,0 +1,52 @@ +//===-- TypePredicates.td ----------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FIR_OPTIMIZER_TYPEPREDICATES +#define FIR_OPTIMIZER_TYPEPREDICATES + +def fir_IntegerType : Type()">, + "FIR integer type">; + +def AnyIntegerLike : TypeConstraint, "any integer">; + +def AnyIntegerType : Type; + +// Reference to an entity type +def fir_ReferenceType : Type()">, + "reference type">; + +// Reference to an ALLOCATABLE attribute type +def fir_HeapType : Type()">, + "allocatable type">; + +// Reference to a POINTER attribute type +def fir_PointerType : Type()">, + "pointer type">; + +// Reference types +def AnyReferenceLike : TypeConstraint, "any reference">; + +// A field (in a RecordType) argument's type +def fir_FieldType : Type()">, "field type">; + +// A LEN parameter (in a RecordType) argument's type +def fir_LenType : Type()">, + "LEN parameter type">; + +// A descriptor tuple (captures a reference to an entity and other information) +def fir_BoxType : Type()">, "box type">; + +def AnyCoordinateLike : TypeConstraint, "any coordinate index">; + +def AnyCoordinateType : Type; + +#endif Index: flang/lib/Optimizer/Dialect/FIRAttr.cpp =================================================================== --- flang/lib/Optimizer/Dialect/FIRAttr.cpp +++ flang/lib/Optimizer/Dialect/FIRAttr.cpp @@ -5,19 +5,22 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Support/KindMapping.h" #include "mlir/IR/AttributeSupport.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/BuiltinTypes.h" #include 
"llvm/ADT/SmallString.h" using namespace fir; -namespace fir { -namespace detail { +namespace fir::detail { struct RealAttributeStorage : public mlir::AttributeStorage { using KeyTy = std::pair; @@ -71,55 +74,98 @@ private: mlir::Type value; }; -} // namespace detail -ExactTypeAttr ExactTypeAttr::get(mlir::Type value) { +/// An attribute representing a raw pointer. +struct OpaqueAttributeStorage : public mlir::AttributeStorage { + using KeyTy = void *; + + OpaqueAttributeStorage(void *value) : value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static OpaqueAttributeStorage * + construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate()) + OpaqueAttributeStorage(key); + } + + void *getPointer() const { return value; } + +private: + void *value; +}; +} // namespace fir::detail + +//===----------------------------------------------------------------------===// +// Attributes for SELECT TYPE +//===----------------------------------------------------------------------===// + +ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) { return Base::get(value.getContext(), value); } -mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); } +mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); } -SubclassAttr SubclassAttr::get(mlir::Type value) { +SubclassAttr fir::SubclassAttr::get(mlir::Type value) { return Base::get(value.getContext(), value); } -mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); } +mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); } + +//===----------------------------------------------------------------------===// +// Attributes for SELECT CASE +//===----------------------------------------------------------------------===// using AttributeUniquer = mlir::detail::AttributeUniquer; -ClosedIntervalAttr 
ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) { +ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) { return AttributeUniquer::get(ctxt); } -UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) { +UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) { return AttributeUniquer::get(ctxt); } -LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) { +LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) { return AttributeUniquer::get(ctxt); } -PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) { +PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) { return AttributeUniquer::get(ctxt); } +//===----------------------------------------------------------------------===// // RealAttr +//===----------------------------------------------------------------------===// -RealAttr RealAttr::get(mlir::MLIRContext *ctxt, - const RealAttr::ValueType &key) { +RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt, + const RealAttr::ValueType &key) { return Base::get(ctxt, key); } -int RealAttr::getFKind() const { return getImpl()->getFKind(); } +int fir::RealAttr::getFKind() const { return getImpl()->getFKind(); } + +llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); } -llvm::APFloat RealAttr::getValue() const { return getImpl()->getValue(); } +//===----------------------------------------------------------------------===// +// OpaqueAttr +//===----------------------------------------------------------------------===// + +OpaqueAttr fir::OpaqueAttr::get(mlir::MLIRContext *ctxt, void *key) { + return Base::get(ctxt, key); +} +void *fir::OpaqueAttr::getPointer() const { return getImpl()->getPointer(); } + +//===----------------------------------------------------------------------===// // FIR attribute parsing +//===----------------------------------------------------------------------===// -namespace { -mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect, - 
mlir::DialectAsmParser &parser, - mlir::Type type) { +static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect, + mlir::DialectAsmParser &parser, + mlir::Type type) { int kind = 0; if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) { parser.emitError(parser.getNameLoc(), "expected '<' kind ','"); @@ -154,11 +200,10 @@ } return RealAttr::get(dialect->getContext(), {kind, value}); } -} // namespace -mlir::Attribute parseFirAttribute(FIROpsDialect *dialect, - mlir::DialectAsmParser &parser, - mlir::Type type) { +mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect, + mlir::DialectAsmParser &parser, + mlir::Type type) { auto loc = parser.getNameLoc(); llvm::StringRef attrName; if (parser.parseKeyword(&attrName)) { @@ -182,6 +227,15 @@ } return SubclassAttr::get(type); } + if (attrName == OpaqueAttr::getAttrName()) { + if (parser.parseLess() || parser.parseGreater()) { + parser.emitError(loc, "expected <>"); + return {}; + } + // NB: opaque pointers are always parsed in as nullptrs. The tool must + // rebuild the context. 
+ return OpaqueAttr::get(dialect->getContext(), nullptr); + } if (attrName == PointIntervalAttr::getAttrName()) return PointIntervalAttr::get(dialect->getContext()); if (attrName == LowerBoundAttr::getAttrName()) @@ -197,10 +251,12 @@ return {}; } +//===----------------------------------------------------------------------===// // FIR attribute pretty printer +//===----------------------------------------------------------------------===// -void printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, - mlir::DialectAsmPrinter &p) { +void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, + mlir::DialectAsmPrinter &p) { auto &os = p.getStream(); if (auto exact = attr.dyn_cast()) { os << fir::ExactTypeAttr::getAttrName() << '<'; @@ -223,9 +279,10 @@ llvm::SmallString<40> ss; a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16); os << ss << '>'; + } else if (attr.isa()) { + os << fir::OpaqueAttr::getAttrName() << "<>"; } else { - llvm_unreachable("attribute pretty-printer is not implemented"); + // don't know how to print the attribute, so use a default + os << "<(unknown attribute)>"; } } - -} // namespace fir Index: flang/lib/Optimizer/Dialect/FIRDialect.cpp =================================================================== --- flang/lib/Optimizer/Dialect/FIRDialect.cpp +++ flang/lib/Optimizer/Dialect/FIRDialect.cpp @@ -5,25 +5,73 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Transforms/InliningUtils.h" using namespace fir; +namespace { 
+/// This class defines the interface for handling inlining of FIR calls. +struct FIRInlinerInterface : public mlir::DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable, + bool wouldBeCloned) const final { + return fir::canLegallyInline(call, callable, wouldBeCloned); + } + + /// This hook checks to see if the operation `op` is legal to inline into the + /// given region `reg`. + bool isLegalToInline(mlir::Operation *op, mlir::Region *reg, + bool wouldBeCloned, + mlir::BlockAndValueMapping &map) const final { + return fir::canLegallyInline(op, reg, wouldBeCloned, map); + } + + /// This hook is called when a terminator operation has been inlined. + /// We handle the return (a Fortran FUNCTION) by replacing the values + /// previously returned by the call operation with the operands of the + /// return. + void handleTerminator(mlir::Operation *op, + llvm::ArrayRef valuesToRepl) const final { + auto returnOp = cast(op); + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } + + mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder, + mlir::Value input, + mlir::Type resultType, + mlir::Location loc) const { + return builder.create(loc, resultType, input); + } +}; +} // namespace + fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx) : mlir::Dialect("fir", ctx, mlir::TypeID::get()) { - addTypes(); - addAttributes(); + addAttributes(); addOperations< #define GET_OP_LIST #include "flang/Optimizer/Dialect/FIROps.cpp.inc" >(); + addInterfaces(); } // anchor the class vtable to this compilation unit Index: flang/lib/Optimizer/Dialect/FIROps.cpp =================================================================== --- flang/lib/Optimizer/Dialect/FIROps.cpp +++ flang/lib/Optimizer/Dialect/FIROps.cpp @@ -5,6 +5,10 @@ // 
SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRAttr.h" @@ -15,6 +19,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" @@ -115,6 +120,21 @@ return HeapType::get(intype); } +//===----------------------------------------------------------------------===// +// ArrayLoadOp +//===----------------------------------------------------------------------===// + +std::vector fir::ArrayLoadOp::getExtents() { + std::vector result; + if (auto sh = shape()) + if (auto *op = sh.getDefiningOp()) { + if (auto shOp = dyn_cast(op)) + return shOp.getExtents(); + return cast(op).getExtents(); + } + return result; +} + //===----------------------------------------------------------------------===// // BoxAddrOp //===----------------------------------------------------------------------===// @@ -158,6 +178,11 @@ // CallOp //===----------------------------------------------------------------------===// +mlir::FunctionType fir::CallOp::getFunctionType() { + return mlir::FunctionType::get(getContext(), getOperandTypes(), + getResultTypes()); +} + static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { auto callee = op.callee(); bool isDirect = callee.hasValue(); @@ -203,12 +228,9 @@ } else { auto funcArgs = llvm::ArrayRef(operands).drop_front(); - llvm::SmallVector resultArgs( - result.operands.begin() + (result.operands.empty() ? 
0 : 1), - result.operands.end()); if (parser.resolveOperand(operands[0], funcType, result.operands) || parser.resolveOperands(funcArgs, funcType.getInputs(), - parser.getNameLoc(), resultArgs)) + parser.getNameLoc(), result.operands)) return mlir::failure(); } result.addTypes(funcType.getResults()); @@ -315,6 +337,10 @@ // ConvertOp //===----------------------------------------------------------------------===// +void fir::ConvertOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { +} + mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef opnds) { if (value().getType() == getType()) return value(); @@ -337,7 +363,7 @@ bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { return ty.isa() || ty.isa() || - ty.isa() || ty.isa() || + ty.isa() || ty.isa() || ty.isa(); } @@ -348,7 +374,7 @@ bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { return ty.isa() || ty.isa() || ty.isa() || ty.isa() || - ty.isa(); + ty.isa() || ty.isa(); } //===----------------------------------------------------------------------===// @@ -357,31 +383,19 @@ static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { - llvm::ArrayRef allOperandTypes; - llvm::ArrayRef allResultTypes; - llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); + auto loc = parser.getCurrentLocation(); llvm::SmallVector allOperands; - if (parser.parseOperandList(allOperands)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - if (parser.parseColon()) - return failure(); - mlir::FunctionType funcTy; - if (parser.parseType(funcTy)) - return failure(); - allOperandTypes = funcTy.getInputs(); - allResultTypes = funcTy.getResults(); - result.addTypes(allResultTypes); - if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, + if (parser.parseOperandList(allOperands) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(funcTy) || + 
parser.resolveOperands(allOperands, funcTy.getInputs(), loc, result.operands)) return failure(); - if (funcTy.getNumInputs()) { - // No inputs handled by verify + parser.addTypesToList(funcTy.getResults(), result.types); + if (funcTy.getNumInputs()) result.addAttribute(fir::CoordinateOp::baseType(), mlir::TypeAttr::get(funcTy.getInput(0))); - } return success(); } @@ -416,8 +430,8 @@ //===----------------------------------------------------------------------===// mlir::FunctionType fir::DispatchOp::getFunctionType() { - auto attr = (*this)->getAttr("fn_type").cast(); - return attr.getValue().cast(); + return mlir::FunctionType::get(getContext(), getOperandTypes(), + getResultTypes()); } //===----------------------------------------------------------------------===// @@ -430,47 +444,6 @@ block.getOperations().insert(block.end(), op); } -//===----------------------------------------------------------------------===// -// EmboxOp -//===----------------------------------------------------------------------===// - -static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - mlir::FunctionType type; - llvm::SmallVector operands; - mlir::OpAsmParser::OperandType memref; - if (parser.parseOperand(memref)) - return mlir::failure(); - operands.push_back(memref); - auto &builder = parser.getBuilder(); - if (!parser.parseOptionalLParen()) { - if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || - parser.parseRParen()) - return mlir::failure(); - auto lens = builder.getI32IntegerAttr(operands.size()); - result.addAttribute(fir::EmboxOp::lenpName(), lens); - } - if (!parser.parseOptionalComma()) { - mlir::OpAsmParser::OperandType dims; - if (parser.parseOperand(dims)) - return mlir::failure(); - operands.push_back(dims); - } else if (!parser.parseOptionalLSquare()) { - mlir::AffineMapAttr map; - if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), - result.attributes) || - parser.parseRSquare()) - return 
mlir::failure(); - } - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), - result.operands) || - parser.addTypesToList(type.getResults(), result.types)) - return mlir::failure(); - return mlir::success(); -} - //===----------------------------------------------------------------------===// // GenTypeDescOp //===----------------------------------------------------------------------===// @@ -491,7 +464,7 @@ auto &builder = parser.getBuilder(); if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { if (fir::GlobalOp::verifyValidLinkage(linkage)) - return failure(); + return mlir::failure(); mlir::StringAttr linkAttr = builder.getStringAttr(linkage); result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); } @@ -500,7 +473,7 @@ mlir::SymbolRefAttr nameAttr; if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), result.attributes)) - return failure(); + return mlir::failure(); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(nameAttr.getRootReference())); @@ -510,7 +483,7 @@ if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), result.attributes) || parser.parseRParen()) - return failure(); + return mlir::failure(); simpleInitializer = true; } @@ -522,7 +495,7 @@ mlir::Type globalType; if (parser.parseColonType(globalType)) - return failure(); + return mlir::failure(); result.addAttribute(fir::GlobalOp::typeAttrName(), mlir::TypeAttr::get(globalType)); @@ -531,11 +504,13 @@ result.addRegion(); } else { // Parse the optional initializer body. 
- if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) - return failure(); + auto parseResult = parser.parseOptionalRegion( + *result.addRegion(), /*arguments=*/llvm::None, /*argTypes=*/llvm::None); + if (parseResult.hasValue() && mlir::failed(*parseResult)) + return mlir::failure(); } - return success(); + return mlir::success(); } void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { @@ -592,11 +567,74 @@ mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { // Supporting only a subset of the LLVM linkage types for now - static const llvm::SmallVector validNames = { - "internal", "common", "weak"}; + static const char *validNames[] = {"common", "internal", "linkonce", "weak"}; return mlir::success(llvm::is_contained(validNames, linkage)); } +//===----------------------------------------------------------------------===// +// InsertValueOp +//===----------------------------------------------------------------------===// + +static bool checkIsIntegerConstant(mlir::Value v, int64_t conVal) { + if (auto c = dyn_cast_or_null(v.getDefiningOp())) { + auto attr = c.getValue(); + if (auto iattr = attr.dyn_cast()) + return iattr.getInt() == conVal; + } + return false; +} +static bool isZero(mlir::Value v) { return checkIsIntegerConstant(v, 0); } +static bool isOne(mlir::Value v) { return checkIsIntegerConstant(v, 1); } + +// These patterns are written by hand because the tablegen pattern language +// isn't adequate here. 
+template +struct UndoComplexPattern : public mlir::RewritePattern { + UndoComplexPattern(mlir::MLIRContext *ctx) + : mlir::RewritePattern("fir.insert_value", {}, 2, ctx) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto insval = dyn_cast_or_null(op); + if (!insval || !insval.getType().isa()) + return mlir::failure(); + auto insval2 = + dyn_cast_or_null(insval.adt().getDefiningOp()); + if (!insval2 || !isa(insval2.adt().getDefiningOp())) + return mlir::failure(); + auto binf = dyn_cast_or_null(insval.val().getDefiningOp()); + auto binf2 = dyn_cast_or_null(insval2.val().getDefiningOp()); + if (!binf || !binf2 || insval.coor().size() != 1 || + !isOne(insval.coor()[0]) || insval2.coor().size() != 1 || + !isZero(insval2.coor()[0])) + return mlir::failure(); + auto eai = + dyn_cast_or_null(binf.lhs().getDefiningOp()); + auto ebi = + dyn_cast_or_null(binf.rhs().getDefiningOp()); + auto ear = + dyn_cast_or_null(binf2.lhs().getDefiningOp()); + auto ebr = + dyn_cast_or_null(binf2.rhs().getDefiningOp()); + if (!eai || !ebi || !ear || !ebr || ear.adt() != eai.adt() || + ebr.adt() != ebi.adt() || eai.coor().size() != 1 || + !isOne(eai.coor()[0]) || ebi.coor().size() != 1 || + !isOne(ebi.coor()[0]) || ear.coor().size() != 1 || + !isZero(ear.coor()[0]) || ebr.coor().size() != 1 || + !isZero(ebr.coor()[0])) + return mlir::failure(); + rewriter.replaceOpWithNewOp(op, ear.adt(), ebr.adt()); + return mlir::success(); + } +}; + +void fir::InsertValueOp::getCanonicalizationPatterns( + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { + results.insert, + UndoComplexPattern>(context); +} + //===----------------------------------------------------------------------===// // IterWhileOp //===----------------------------------------------------------------------===// @@ -604,9 +642,14 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, 
mlir::Value ub, mlir::Value step, - mlir::Value iterate, mlir::ValueRange iterArgs, + mlir::Value iterate, bool finalCountValue, + mlir::ValueRange iterArgs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step, iterate}); + if (finalCountValue) { + result.addTypes(builder.getIndexType()); + result.addAttribute(finalValueAttrName(), builder.getUnitAttr()); + } result.addTypes(iterate.getType()); result.addOperands(iterArgs); for (auto v : iterArgs) @@ -648,24 +691,50 @@ // Parse the initial iteration arguments. llvm::SmallVector regionArgs; + auto prependCount = false; + // Induction variable. regionArgs.push_back(inductionVariable); regionArgs.push_back(iterateVar); - result.addTypes(i1Type); - if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { + if (succeeded(parser.parseOptionalKeyword("iter_args"))) { llvm::SmallVector operands; llvm::SmallVector regionTypes; // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(regionTypes)) - return mlir::failure(); + return failure(); + if (regionTypes.size() == operands.size() + 2) + prependCount = true; + llvm::ArrayRef resTypes = regionTypes; + resTypes = prependCount ? resTypes.drop_front(2) : resTypes; // Resolve input operands. - for (auto operand_type : llvm::zip(operands, regionTypes)) + for (auto operand_type : llvm::zip(operands, resTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) - return mlir::failure(); - result.addTypes(regionTypes); + return failure(); + if (prependCount) { + // This is an assert here, because these types are verified. 
+ assert(regionTypes[0].isa() && + regionTypes[1].isSignlessInteger(1)); + result.addTypes(regionTypes); + } else { + result.addTypes(i1Type); + result.addTypes(resTypes); + } + } else if (succeeded(parser.parseOptionalArrow())) { + llvm::SmallVector typeList; + if (parser.parseLParen() || parser.parseTypeList(typeList) || + parser.parseRParen()) + return failure(); + // Type list must be "(index, i1)". + if (typeList.size() != 2 || !typeList[0].isa() || + !typeList[1].isSignlessInteger(1)) + return failure(); + result.addTypes(typeList); + prependCount = true; + } else { + result.addTypes(i1Type); } if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) @@ -673,7 +742,11 @@ llvm::SmallVector argTypes; // Induction variable (hidden) - argTypes.push_back(indexType); + if (prependCount) + result.addAttribute(IterWhileOp::finalValueAttrName(), + builder.getUnitAttr()); + else + argTypes.push_back(indexType); // Loop carried variables (including iterate) argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. @@ -692,10 +765,6 @@ } static mlir::LogicalResult verify(fir::IterWhileOp op) { - if (auto cst = dyn_cast_or_null(op.step().getDefiningOp())) - if (cst.getValue() <= 0) - return op.emitOpError("constant step operand must be positive"); - // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); @@ -709,6 +778,19 @@ "the induction variable"); auto opNumResults = op.getNumResults(); + if (op.finalValue()) { + // Result type must be "(index, i1, ...)". + if (!op.getResult(0).getType().isa()) + return op.emitOpError("result #0 expected to be index"); + if (!op.getResult(1).getType().isSignlessInteger(1)) + return op.emitOpError("result #1 expected to be i1"); + opNumResults--; + } else { + // iterate_while always returns the early exit induction value. 
+ // Result type must be "(i1, ...)" + if (!op.getResult(0).getType().isSignlessInteger(1)) + return op.emitOpError("result #0 expected to be i1"); + } if (opNumResults == 0) return mlir::failure(); if (op.getNumIterOperands() != opNumResults) @@ -719,7 +801,8 @@ "mismatch in number of basic block args and defined values"); auto iterOperands = op.getIterOperands(); auto iterArgs = op.getRegionIterArgs(); - auto opResults = op.getResults(); + auto opResults = + op.finalValue() ? op.getResults().drop_front() : op.getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) @@ -747,9 +830,14 @@ llvm::interleaveComma( llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); - p << ") -> (" << op.getResultTypes().drop_front() << ')'; + auto resTypes = op.finalValue() ? op.getResultTypes() + : op.getResultTypes().drop_front(); + p << ") -> (" << resTypes << ')'; + } else if (op.finalValue()) { + p << " -> (" << op.getResultTypes() << ')'; } - p.printOptionalAttrDictWithKeyword(op->getAttrs(), {}); + p.printOptionalAttrDictWithKeyword(op->getAttrs(), + {IterWhileOp::finalValueAttrName()}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -767,6 +855,27 @@ return success(); } +mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { + for (auto i : llvm::enumerate(initArgs())) + if (iterArg == i.value()) + return region().front().getArgument(i.index() + 1); + return {}; +} + +void fir::IterWhileOp::resultToSourceOps( + llvm::SmallVectorImpl &results, unsigned resultNum) { + auto oper = finalValue() ? 
resultNum : resultNum + 1; + auto *term = region().front().getTerminator(); + if (oper < term->getNumOperands()) + results.push_back(term->getOperand(oper)); +} + +mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { + if (blockArgNum > 0 && blockArgNum <= initArgs().size()) + return initArgs()[blockArgNum - 1]; + return {}; +} + //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// @@ -786,21 +895,26 @@ } //===----------------------------------------------------------------------===// -// LoopOp +// DoLoopOp //===----------------------------------------------------------------------===// -void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, - mlir::Value lb, mlir::Value ub, mlir::Value step, - bool unordered, mlir::ValueRange iterArgs, - llvm::ArrayRef attributes) { +void fir::DoLoopOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value lb, + mlir::Value ub, mlir::Value step, bool unordered, + bool finalCountValue, mlir::ValueRange iterArgs, + llvm::ArrayRef attributes) { result.addOperands({lb, ub, step}); result.addOperands(iterArgs); + if (finalCountValue) { + result.addTypes(builder.getIndexType()); + result.addAttribute(finalValueAttrName(), builder.getUnitAttr()); + } for (auto v : iterArgs) result.addTypes(v.getType()); mlir::Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block{}); - if (iterArgs.empty()) - LoopOp::ensureTerminator(*bodyRegion, builder, result.location); + if (iterArgs.empty() && !finalCountValue) + DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); bodyRegion->front().addArgument(builder.getIndexType()); bodyRegion->front().addArguments(iterArgs.getTypes()); if (unordered) @@ -808,8 +922,8 @@ result.addAttributes(attributes); } -static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, - mlir::OperationState 
&result) { +static mlir::ParseResult parseDoLoopOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. @@ -827,12 +941,13 @@ return failure(); if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) - result.addAttribute(fir::LoopOp::unorderedAttrName(), + result.addAttribute(fir::DoLoopOp::unorderedAttrName(), builder.getUnitAttr()); // Parse the optional initial iteration arguments. llvm::SmallVector regionArgs, operands; llvm::SmallVector argTypes; + auto prependCount = false; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { @@ -840,18 +955,30 @@ if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(result.types)) return failure(); + if (result.types.size() == operands.size() + 1) + prependCount = true; // Resolve input operands. - for (auto operand_type : llvm::zip(operands, result.types)) + llvm::ArrayRef resTypes = result.types; + for (auto operand_type : + llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return failure(); + } else if (succeeded(parser.parseOptionalArrow())) { + if (parser.parseKeyword("index")) + return failure(); + result.types.push_back(indexType); + prependCount = true; } if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); // Induction variable. - argTypes.push_back(indexType); + if (prependCount) + result.addAttribute(DoLoopOp::finalValueAttrName(), builder.getUnitAttr()); + else + argTypes.push_back(indexType); // Loop carried variables argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. 
@@ -864,26 +991,22 @@ if (parser.parseRegion(*body, regionArgs, argTypes)) return failure(); - fir::LoopOp::ensureTerminator(*body, builder, result.location); + DoLoopOp::ensureTerminator(*body, builder, result.location); return mlir::success(); } -fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { +fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { auto ivArg = val.dyn_cast(); if (!ivArg) return {}; assert(ivArg.getOwner() && "unlinked block argument"); auto *containingInst = ivArg.getOwner()->getParentOp(); - return dyn_cast_or_null(containingInst); + return dyn_cast_or_null(containingInst); } // Lifted from loop.loop -static mlir::LogicalResult verify(fir::LoopOp op) { - if (auto cst = dyn_cast_or_null(op.step().getDefiningOp())) - if (cst.getValue() <= 0) - return op.emitOpError("constant step operand must be positive"); - +static mlir::LogicalResult verify(fir::DoLoopOp op) { // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); @@ -895,6 +1018,12 @@ auto opNumResults = op.getNumResults(); if (opNumResults == 0) return success(); + + if (op.finalValue()) { + if (op.unordered()) + return op.emitOpError("unordered loop has no final value"); + opNumResults--; + } if (op.getNumIterOperands() != opNumResults) return op.emitOpError( "mismatch in number of loop-carried values and defined values"); @@ -903,7 +1032,8 @@ "mismatch in number of basic block args and defined values"); auto iterOperands = op.getIterOperands(); auto iterArgs = op.getRegionIterArgs(); - auto opResults = op.getResults(); + auto opResults = + op.finalValue() ? 
op.getResults().drop_front() : op.getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) @@ -918,9 +1048,9 @@ return success(); } -static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { +static void print(mlir::OpAsmPrinter &p, fir::DoLoopOp op) { bool printBlockTerminators = false; - p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " + p << fir::DoLoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); if (op.unordered()) p << " unordered"; @@ -933,26 +1063,57 @@ }); p << ") -> (" << op.getResultTypes() << ')'; printBlockTerminators = true; + } else if (op.finalValue()) { + p << " -> " << op.getResultTypes(); + printBlockTerminators = true; } p.printOptionalAttrDictWithKeyword(op->getAttrs(), - {fir::LoopOp::unorderedAttrName()}); + {fir::DoLoopOp::unorderedAttrName(), + fir::DoLoopOp::finalValueAttrName()}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); } -mlir::Region &fir::LoopOp::getLoopBody() { return region(); } +mlir::Region &fir::DoLoopOp::getLoopBody() { return region(); } -bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { +bool fir::DoLoopOp::isDefinedOutsideOfLoop(mlir::Value value) { return !region().isAncestor(value.getParentRegion()); } mlir::LogicalResult -fir::LoopOp::moveOutOfLoop(llvm::ArrayRef ops) { +fir::DoLoopOp::moveOutOfLoop(llvm::ArrayRef ops) { for (auto op : ops) op->moveBefore(*this); return success(); } +/// Translate a value passed as an iter_arg to the corresponding block +/// argument in the body of the loop. 
+mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { + for (auto i : llvm::enumerate(initArgs())) + if (iterArg == i.value()) + return region().front().getArgument(i.index() + 1); + return {}; +} + +/// Translate the result vector (by index number) to the corresponding value +/// to the `fir.result` Op. +void fir::DoLoopOp::resultToSourceOps( + llvm::SmallVectorImpl &results, unsigned resultNum) { + auto oper = finalValue() ? resultNum : resultNum + 1; + auto *term = region().front().getTerminator(); + if (oper < term->getNumOperands()) + results.push_back(term->getOperand(oper)); +} + +/// Translate the block argument (by index number) to the corresponding value +/// passed as an iter_arg to the parent DoLoopOp. +mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { + if (blockArgNum > 0 && blockArgNum <= initArgs().size()) + return initArgs()[blockArgNum - 1]; + return {}; +} + //===----------------------------------------------------------------------===// // MulfOp //===----------------------------------------------------------------------===// @@ -995,7 +1156,7 @@ template static A getSubOperands(unsigned pos, A allArgs, mlir::DenseIntElementsAttr ranges, - AdditionalArgs &&... 
additionalArgs) { + AdditionalArgs &&...additionalArgs) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += (*(ranges.begin() + i)).getZExtValue(); @@ -1132,7 +1293,7 @@ return mlir::failure(); dests.push_back(dest); destArgs.push_back(destArg); - if (!parser.parseOptionalRSquare()) + if (mlir::succeeded(parser.parseOptionalRSquare())) break; if (parser.parseComma()) return mlir::failure(); @@ -1333,7 +1494,7 @@ attrs.push_back(attr); dests.push_back(dest); destArgs.push_back(destArg); - if (!parser.parseOptionalRSquare()) + if (mlir::succeeded(parser.parseOptionalRSquare())) break; if (parser.parseComma()) return mlir::failure(); @@ -1395,34 +1556,35 @@ } //===----------------------------------------------------------------------===// -// WhereOp +// IfOp //===----------------------------------------------------------------------===// -void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, - mlir::Value cond, bool withElseRegion) { + +void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result, + mlir::Value cond, bool withElseRegion) { build(builder, result, llvm::None, cond, withElseRegion); } -void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, - mlir::TypeRange resultTypes, mlir::Value cond, - bool withElseRegion) { +void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result, + mlir::TypeRange resultTypes, mlir::Value cond, + bool withElseRegion) { result.addOperands(cond); result.addTypes(resultTypes); mlir::Region *thenRegion = result.addRegion(); thenRegion->push_back(new mlir::Block()); if (resultTypes.empty()) - WhereOp::ensureTerminator(*thenRegion, builder, result.location); + IfOp::ensureTerminator(*thenRegion, builder, result.location); mlir::Region *elseRegion = result.addRegion(); if (withElseRegion) { elseRegion->push_back(new mlir::Block()); if (resultTypes.empty()) - WhereOp::ensureTerminator(*elseRegion, builder, result.location); + 
IfOp::ensureTerminator(*elseRegion, builder, result.location); } } -static mlir::ParseResult parseWhereOp(OpAsmParser &parser, - OperationState &result) { +static mlir::ParseResult parseIfOp(OpAsmParser &parser, + OperationState &result) { result.regions.reserve(2); mlir::Region *thenRegion = result.addRegion(); mlir::Region *elseRegion = result.addRegion(); @@ -1434,44 +1596,44 @@ parser.resolveOperand(cond, i1Type, result.operands)) return mlir::failure(); - if (parser.parseRegion(*thenRegion, {}, {})) + if (parser.parseOptionalArrowTypeList(result.types)) return mlir::failure(); - WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); + if (parser.parseRegion(*thenRegion, {}, {})) + return mlir::failure(); + IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); - if (!parser.parseOptionalKeyword("else")) { + if (mlir::succeeded(parser.parseOptionalKeyword("else"))) { if (parser.parseRegion(*elseRegion, {}, {})) return mlir::failure(); - WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), - result.location); + IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); } // Parse the optional attribute list. 
if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure(); - return mlir::success(); } -static LogicalResult verify(fir::WhereOp op) { - if (op.getNumResults() != 0 && op.otherRegion().empty()) +static LogicalResult verify(fir::IfOp op) { + if (op.getNumResults() != 0 && op.elseRegion().empty()) return op.emitOpError("must have an else block if defining values"); return mlir::success(); } -static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { +static void print(mlir::OpAsmPrinter &p, fir::IfOp op) { bool printBlockTerminators = false; - p << fir::WhereOp::getOperationName() << ' ' << op.condition(); + p << fir::IfOp::getOperationName() << ' ' << op.condition(); if (!op.results().empty()) { p << " -> (" << op.getResultTypes() << ')'; printBlockTerminators = true; } - p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, + p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, printBlockTerminators); // Print the 'else' regions if it exists and has a block. 
- auto &otherReg = op.otherRegion(); + auto &otherReg = op.elseRegion(); if (!otherReg.empty()) { p << " else"; p.printRegion(otherReg, /*printEntryBlockArgs=*/false, @@ -1480,6 +1642,16 @@ p.printOptionalAttrDict(op->getAttrs()); } +void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl &results, + unsigned resultNum) { + auto *term = thenRegion().front().getTerminator(); + if (resultNum < term->getNumOperands()) + results.push_back(term->getOperand(resultNum)); + term = elseRegion().front().getTerminator(); + if (resultNum < term->getNumOperands()) + results.push_back(term->getOperand(resultNum)); +} + //===----------------------------------------------------------------------===// mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { @@ -1549,7 +1721,9 @@ return f; mlir::OpBuilder modBuilder(module.getBodyRegion()); modBuilder.setInsertionPoint(module.getBody()->getTerminator()); - return modBuilder.create(loc, name, type, attrs); + auto result = modBuilder.create(loc, name, type, attrs); + result.setVisibility(mlir::SymbolTable::Visibility::Private); + return result; } fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, @@ -1558,7 +1732,9 @@ if (auto g = module.lookupSymbol(name)) return g; mlir::OpBuilder modBuilder(module.getBodyRegion()); - return modBuilder.create(loc, name, type, attrs); + auto result = modBuilder.create(loc, name, type, attrs); + result.setVisibility(mlir::SymbolTable::Visibility::Private); + return result; } // Tablegen operators Index: flang/lib/Optimizer/Dialect/FIRType.cpp =================================================================== --- flang/lib/Optimizer/Dialect/FIRType.cpp +++ flang/lib/Optimizer/Dialect/FIRType.cpp @@ -5,6 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// 
+//===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/FIRDialect.h" @@ -81,19 +85,45 @@ return parseTypeSingleton(parser, loc); } -// `char` `<` kind `>` +// `char` `<` kind [`,` `len`] `>` CharacterType parseCharacter(mlir::DialectAsmParser &parser) { - return parseKindSingleton(parser); + int kind = 0; + if (parser.parseLess() || parser.parseInteger(kind)) { + parser.emitError(parser.getCurrentLocation(), "kind value expected"); + return {}; + } + CharacterType::LenType len = 1; + if (mlir::succeeded(parser.parseOptionalComma())) { + if (mlir::succeeded(parser.parseOptionalQuestion())) { + len = fir::CharacterType::unknownLen(); + } else if (!mlir::succeeded(parser.parseInteger(len))) { + parser.emitError(parser.getCurrentLocation(), "len value expected"); + return {}; + } + } + if (parser.parseGreater()) + return {}; + return CharacterType::get(parser.getBuilder().getContext(), kind, len); } // `complex` `<` kind `>` -CplxType parseComplex(mlir::DialectAsmParser &parser) { - return parseKindSingleton(parser); +fir::ComplexType parseComplex(mlir::DialectAsmParser &parser) { + return parseKindSingleton(parser); +} + +// `shape` `<` rank `>` +ShapeType parseShape(mlir::DialectAsmParser &parser) { + return parseRankSingleton(parser); +} + +// `shapeshift` `<` rank `>` +ShapeShiftType parseShapeShift(mlir::DialectAsmParser &parser) { + return parseRankSingleton(parser); } -// `dims` `<` rank `>` -DimsType parseDims(mlir::DialectAsmParser &parser) { - return parseRankSingleton(parser); +// `slice` `<` rank `>` +SliceType parseSlice(mlir::DialectAsmParser &parser) { + return parseRankSingleton(parser); } // `field` @@ -107,8 +137,8 @@ } // `int` `<` kind `>` -IntType parseInteger(mlir::DialectAsmParser &parser) { - return parseKindSingleton(parser); +fir::IntegerType parseInteger(mlir::DialectAsmParser &parser) { + return parseKindSingleton(parser); } // `len` @@ 
-142,6 +172,19 @@ return parseTypeSingleton(parser, loc); } +// `vector` `<` len `:` type `>` +fir::VectorType parseVector(mlir::DialectAsmParser &parser, + mlir::Location loc) { + int64_t len = 0; + mlir::Type eleTy; + if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() || + parser.parseType(eleTy) || parser.parseGreater()) { + parser.emitError(parser.getNameLoc(), "invalid vector type"); + return {}; + } + return fir::VectorType::get(len, eleTy); +} + // `void` mlir::Type parseVoid(mlir::DialectAsmParser &parser) { return parser.getBuilder().getNoneType(); @@ -156,7 +199,7 @@ } SequenceType::Shape shape; if (parser.parseOptionalStar()) { - if (parser.parseDimensionList(shape, true)) { + if (parser.parseDimensionList(shape, /*allowDynamic=*/true)) { parser.emitError(parser.getNameLoc(), "invalid shape"); return {}; } @@ -181,14 +224,15 @@ /// Is `ty` a standard or FIR integer type? static bool isaIntegerType(mlir::Type ty) { // TODO: why aren't we using isa_integer? investigatation required. 
- return ty.isa() || ty.isa(); + return ty.isa() || ty.isa(); } bool verifyRecordMemberType(mlir::Type ty) { return !(ty.isa() || ty.isa() || - ty.isa() || ty.isa() || ty.isa() || - ty.isa() || ty.isa() || - ty.isa()); + ty.isa() || ty.isa() || + ty.isa() || ty.isa() || + ty.isa() || ty.isa() || + ty.isa() || ty.isa()); } bool verifySameLists(llvm::ArrayRef a1, @@ -325,8 +369,6 @@ return parseCharacter(parser); if (typeNameLit == "complex") return parseComplex(parser); - if (typeNameLit == "dims") - return parseDims(parser); if (typeNameLit == "field") return parseField(parser); if (typeNameLit == "heap") @@ -343,12 +385,20 @@ return parseReal(parser); if (typeNameLit == "ref") return parseReference(parser, loc); + if (typeNameLit == "shape") + return parseShape(parser); + if (typeNameLit == "shapeshift") + return parseShapeShift(parser); + if (typeNameLit == "slice") + return parseSlice(parser); if (typeNameLit == "tdesc") return parseTypeDesc(parser, loc); if (typeNameLit == "type") return parseDerived(parser, loc); if (typeNameLit == "void") return parseVoid(parser); + if (typeNameLit == "vector") + return parseVector(parser, loc); parser.emitError(parser.getNameLoc(), "unknown FIR type " + typeNameLit); return {}; @@ -361,39 +411,92 @@ /// `CHARACTER` storage struct CharacterTypeStorage : public mlir::TypeStorage { - using KeyTy = KindTy; + using KeyTy = std::tuple; - static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); } + static unsigned hashKey(const KeyTy &key) { + auto hashVal = llvm::hash_combine(std::get<0>(key)); + return llvm::hash_combine(hashVal, llvm::hash_combine(std::get<1>(key))); + } - bool operator==(const KeyTy &key) const { return key == getFKind(); } + bool operator==(const KeyTy &key) const { + return key == KeyTy{getFKind(), getLen()}; + } static CharacterTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - KindTy kind) { + const KeyTy &key) { auto *storage = allocator.allocate(); - return new (storage) 
CharacterTypeStorage{kind}; + return new (storage) + CharacterTypeStorage{std::get<0>(key), std::get<1>(key)}; } KindTy getFKind() const { return kind; } + CharacterType::LenType getLen() const { return len; } protected: KindTy kind; + CharacterType::LenType len; private: CharacterTypeStorage() = delete; - explicit CharacterTypeStorage(KindTy kind) : kind{kind} {} + explicit CharacterTypeStorage(KindTy kind, CharacterType::LenType len) + : kind{kind}, len{len} {} +}; + +struct ShapeTypeStorage : public mlir::TypeStorage { + using KeyTy = unsigned; + + static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); } + + bool operator==(const KeyTy &key) const { return key == getRank(); } + + static ShapeTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + unsigned rank) { + auto *storage = allocator.allocate(); + return new (storage) ShapeTypeStorage{rank}; + } + + unsigned getRank() const { return rank; } + +protected: + unsigned rank; + +private: + ShapeTypeStorage() = delete; + explicit ShapeTypeStorage(unsigned rank) : rank{rank} {} }; +struct ShapeShiftTypeStorage : public mlir::TypeStorage { + using KeyTy = unsigned; + + static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); } + + bool operator==(const KeyTy &key) const { return key == getRank(); } -struct DimsTypeStorage : public mlir::TypeStorage { + static ShapeShiftTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + unsigned rank) { + auto *storage = allocator.allocate(); + return new (storage) ShapeShiftTypeStorage{rank}; + } + + unsigned getRank() const { return rank; } + +protected: + unsigned rank; + +private: + ShapeShiftTypeStorage() = delete; + explicit ShapeShiftTypeStorage(unsigned rank) : rank{rank} {} +}; +struct SliceTypeStorage : public mlir::TypeStorage { using KeyTy = unsigned; static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); } bool operator==(const KeyTy &key) const { return key == getRank(); } - static 
DimsTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - unsigned rank) { - auto *storage = allocator.allocate(); - return new (storage) DimsTypeStorage{rank}; + static SliceTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + unsigned rank) { + auto *storage = allocator.allocate(); + return new (storage) SliceTypeStorage{rank}; } unsigned getRank() const { return rank; } @@ -402,8 +505,8 @@ unsigned rank; private: - DimsTypeStorage() = delete; - explicit DimsTypeStorage(unsigned rank) : rank{rank} {} + SliceTypeStorage() = delete; + explicit SliceTypeStorage(unsigned rank) : rank{rank} {} }; /// The type of a derived type part reference @@ -469,17 +572,17 @@ }; /// `INTEGER` storage -struct IntTypeStorage : public mlir::TypeStorage { +struct IntegerTypeStorage : public mlir::TypeStorage { using KeyTy = KindTy; static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); } bool operator==(const KeyTy &key) const { return key == getFKind(); } - static IntTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - KindTy kind) { - auto *storage = allocator.allocate(); - return new (storage) IntTypeStorage{kind}; + static IntegerTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + KindTy kind) { + auto *storage = allocator.allocate(); + return new (storage) IntegerTypeStorage{kind}; } KindTy getFKind() const { return kind; } @@ -488,22 +591,22 @@ KindTy kind; private: - IntTypeStorage() = delete; - explicit IntTypeStorage(KindTy kind) : kind{kind} {} + IntegerTypeStorage() = delete; + explicit IntegerTypeStorage(KindTy kind) : kind{kind} {} }; /// `COMPLEX` storage -struct CplxTypeStorage : public mlir::TypeStorage { +struct ComplexTypeStorage : public mlir::TypeStorage { using KeyTy = KindTy; static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); } bool operator==(const KeyTy &key) const { return key == getFKind(); } - static CplxTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - 
KindTy kind) { - auto *storage = allocator.allocate(); - return new (storage) CplxTypeStorage{kind}; + static ComplexTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + KindTy kind) { + auto *storage = allocator.allocate(); + return new (storage) ComplexTypeStorage{kind}; } KindTy getFKind() const { return kind; } @@ -512,8 +615,8 @@ KindTy kind; private: - CplxTypeStorage() = delete; - explicit CplxTypeStorage(KindTy kind) : kind{kind} {} + ComplexTypeStorage() = delete; + explicit ComplexTypeStorage(KindTy kind) : kind{kind} {} }; /// `REAL` storage (for reals of unsupported sizes) @@ -591,9 +694,9 @@ KindTy getFKind() const { return kind; } - // a !fir.boxchar always wraps a !fir.char + // a !fir.boxchar always wraps a !fir.char CharacterType getElementType(mlir::MLIRContext *ctxt) const { - return CharacterType::get(ctxt, getFKind()); + return CharacterType::getUnknownLen(ctxt, getFKind()); } protected: @@ -710,7 +813,7 @@ std::tuple; static unsigned hashKey(const KeyTy &key) { - auto shapeHash{hash_value(std::get(key))}; + auto shapeHash = hash_value(std::get(key)); shapeHash = llvm::hash_combine(shapeHash, std::get(key)); return llvm::hash_combine(shapeHash, std::get(key)); } @@ -816,6 +919,39 @@ explicit TypeDescTypeStorage(mlir::Type ofTy) : ofTy{ofTy} {} }; +/// Vector type storage +struct VectorTypeStorage : public mlir::TypeStorage { + using KeyTy = std::tuple; + + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(std::get(key), + std::get(key)); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy{getLen(), getEleTy()}; + } + + static VectorTypeStorage *construct(mlir::TypeStorageAllocator &allocator, + const KeyTy &key) { + auto *storage = allocator.allocate(); + return new (storage) + VectorTypeStorage{std::get(key), std::get(key)}; + } + + uint64_t getLen() const { return len; } + mlir::Type getEleTy() const { return eleTy; } + +protected: + uint64_t len; + mlir::Type eleTy; + +private: + 
VectorTypeStorage() = delete; + explicit VectorTypeStorage(uint64_t len, mlir::Type eleTy) + : len{len}, eleTy{eleTy} {} +}; + } // namespace detail template @@ -833,8 +969,8 @@ bool isa_fir_or_std_type(mlir::Type t) { if (auto funcType = t.dyn_cast()) - return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) && - llvm::all_of(funcType.getResults(), isa_fir_or_std_type); + return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) && + llvm::all_of(funcType.getResults(), isa_fir_or_std_type); return isa_fir_type(t) || isa_std_type(t); } @@ -847,11 +983,13 @@ } bool isa_passbyref_type(mlir::Type t) { - return t.isa() || isa_box_type(t); + return t.isa() || isa_box_type(t) || + t.isa(); } bool isa_aggregate(mlir::Type t) { - return t.isa() || t.isa(); + return t.isa() || t.isa() || + t.isa(); } mlir::Type dyn_cast_ptrEleTy(mlir::Type t) { @@ -865,20 +1003,17 @@ // CHARACTER -CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, kind); +CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind, + CharacterType::LenType len) { + return Base::get(ctxt, kind, len); } -int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); } +KindTy fir::CharacterType::getFKind() const { return getImpl()->getFKind(); } -// Dims - -DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) { - return Base::get(ctxt, rank); +CharacterType::LenType fir::CharacterType::getLen() const { + return getImpl()->getLen(); } -unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); } - // Field FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) { @@ -897,27 +1032,27 @@ return Base::get(ctxt, kind); } -int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); } +KindTy fir::LogicalType::getFKind() const { return getImpl()->getFKind(); } // INTEGER -IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) { +fir::IntegerType 
fir::IntegerType::get(mlir::MLIRContext *ctxt, KindTy kind) { return Base::get(ctxt, kind); } -int fir::IntType::getFKind() const { return getImpl()->getFKind(); } +KindTy fir::IntegerType::getFKind() const { return getImpl()->getFKind(); } // COMPLEX -CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) { +fir::ComplexType fir::ComplexType::get(mlir::MLIRContext *ctxt, KindTy kind) { return Base::get(ctxt, kind); } -mlir::Type fir::CplxType::getElementType() const { +mlir::Type fir::ComplexType::getElementType() const { return fir::RealType::get(getContext(), getFKind()); } -KindTy fir::CplxType::getFKind() const { return getImpl()->getFKind(); } +KindTy fir::ComplexType::getFKind() const { return getImpl()->getFKind(); } // REAL @@ -925,7 +1060,7 @@ return Base::get(ctxt, kind); } -int fir::RealType::getFKind() const { return getImpl()->getFKind(); } +KindTy fir::RealType::getFKind() const { return getImpl()->getFKind(); } // Box @@ -992,8 +1127,10 @@ mlir::LogicalResult fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc, mlir::Type eleTy) { - if (eleTy.isa() || eleTy.isa() || eleTy.isa() || - eleTy.isa() || eleTy.isa()) + if (eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || + eleTy.isa()) return mlir::emitError(loc, "cannot build a reference to type: ") << eleTy << '\n'; return mlir::success(); @@ -1012,7 +1149,8 @@ static bool canBePointerOrHeapElementType(mlir::Type eleTy) { return eleTy.isa() || eleTy.isa() || - eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa(); @@ -1100,10 +1238,12 @@ mlir::AffineMapAttr map) { // DIMENSION attribute can only be applied to an intrinsic or record type if (eleTy.isa() || eleTy.isa() || - eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || 
eleTy.isa() || eleTy.isa() || - eleTy.isa() || eleTy.isa()) + eleTy.isa() || eleTy.isa() || + eleTy.isa()) return mlir::emitError(loc, "cannot build an array of this element type: ") << eleTy << '\n'; return mlir::success(); @@ -1129,6 +1269,31 @@ return llvm::hash_combine(0); } +// Shape + +ShapeType fir::ShapeType::get(mlir::MLIRContext *ctxt, unsigned rank) { + return Base::get(ctxt, rank); +} + +unsigned fir::ShapeType::getRank() const { return getImpl()->getRank(); } + +// Shapeshift + +ShapeShiftType fir::ShapeShiftType::get(mlir::MLIRContext *ctxt, + unsigned rank) { + return Base::get(ctxt, rank); +} + +unsigned fir::ShapeShiftType::getRank() const { return getImpl()->getRank(); } + +// Slice + +SliceType fir::SliceType::get(mlir::MLIRContext *ctxt, unsigned rank) { + return Base::get(ctxt, rank); +} + +unsigned fir::SliceType::getRank() const { return getImpl()->getRank(); } + /// RecordType /// /// This type captures a Fortran "derived type" @@ -1171,9 +1336,9 @@ return {}; } -/// Type descriptor type -/// -/// This is the type of a type descriptor object (similar to a class instance) +//===----------------------------------------------------------------------===// +// Type descriptor type +//===----------------------------------------------------------------------===// TypeDescType fir::TypeDescType::get(mlir::Type ofType) { assert(!ofType.isa()); @@ -1186,7 +1351,8 @@ fir::TypeDescType::verifyConstructionInvariants(mlir::Location loc, mlir::Type eleTy) { if (eleTy.isa() || eleTy.isa() || - eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || + eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa() || eleTy.isa()) return mlir::emitError(loc, "cannot build a type descriptor of type: ") @@ -1194,16 +1360,36 @@ return mlir::success(); } +//===----------------------------------------------------------------------===// +// Vector type +//===----------------------------------------------------------------------===// + +fir::VectorType 
fir::VectorType::get(uint64_t len, mlir::Type eleTy) { + return Base::get(eleTy.getContext(), len, eleTy); +} + +mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); } + +uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); } + +mlir::LogicalResult +fir::VectorType::verifyConstructionInvariants(mlir::Location loc, uint64_t len, + mlir::Type eleTy) { + if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy))) + return mlir::emitError(loc, "cannot build a vector of type ") + << eleTy << '\n'; + return mlir::success(); +} + namespace { void printBounds(llvm::raw_ostream &os, const SequenceType::Shape &bounds) { os << '<'; for (auto &b : bounds) { - if (b >= 0) { + if (b >= 0) os << b << 'x'; - } else { + else os << "?x"; - } } } @@ -1241,15 +1427,27 @@ os << '>'; return; } - if (auto type = ty.dyn_cast()) { - os << "char<" << type.getFKind() << '>'; + if (auto chTy = ty.dyn_cast()) { + // Fortran intrinsic type CHARACTER + os << "char<" << chTy.getFKind(); + auto len = chTy.getLen(); + if (len != fir::CharacterType::singleton()) { + os << ','; + if (len == fir::CharacterType::unknownLen()) + os << '?'; + else + os << len; + } + os << '>'; return; } - if (auto type = ty.dyn_cast()) { + if (auto type = ty.dyn_cast()) { + // Fortran intrinsic type COMPLEX os << "complex<" << type.getFKind() << '>'; return; } if (auto type = ty.dyn_cast()) { + // Fortran derived type os << "type<" << type.getName(); if (!recordTypeVisited.count(type.uniqueKey())) { recordTypeVisited.insert(type.uniqueKey()); @@ -1276,8 +1474,16 @@ os << '>'; return; } - if (auto type = ty.dyn_cast()) { - os << "dims<" << type.getRank() << '>'; + if (auto type = ty.dyn_cast()) { + os << "shape<" << type.getRank() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { + os << "shapeshift<" << type.getRank() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { + os << "slice<" << type.getRank() << '>'; return; } if (ty.isa()) { @@ -1290,7 +1496,8 @@ os << '>'; 
return; } - if (auto type = ty.dyn_cast()) { + if (auto type = ty.dyn_cast()) { + // Fortran intrinsic type INTEGER os << "int<" << type.getFKind() << '>'; return; } @@ -1299,6 +1506,7 @@ return; } if (auto type = ty.dyn_cast()) { + // Fortran intrinsic type LOGICAL os << "logical<" << type.getFKind() << '>'; return; } @@ -1309,6 +1517,7 @@ return; } if (auto type = ty.dyn_cast()) { + // Fortran intrinsic types REAL and DOUBLE PRECISION os << "real<" << type.getFKind() << '>'; return; } @@ -1340,4 +1549,24 @@ os << '>'; return; } + if (auto type = ty.dyn_cast()) { + os << "vector<" << type.getLen() << ':'; + p.printType(type.getEleTy()); + os << '>'; + return; + } +} + +bool fir::isa_unknown_size_box(mlir::Type t) { + if (auto boxTy = t.dyn_cast()) { + auto eleTy = boxTy.getEleTy(); + if (auto actualEleTy = fir::dyn_cast_ptrEleTy(eleTy)) + eleTy = actualEleTy; + if (eleTy.isa()) + return true; + if (auto seqTy = eleTy.dyn_cast()) + if (seqTy.hasUnknownShape()) + return true; + } + return false; } Index: flang/test/Fir/fir-ops.fir =================================================================== --- flang/test/Fir/fir-ops.fir +++ flang/test/Fir/fir-ops.fir @@ -1,7 +1,6 @@ // Test the FIR operations // RUN: tco -emit-fir %s | tco -emit-fir | FileCheck %s -// UNSUPPORTED: !fir // CHECK-LABEL: func private @it1() -> !fir.int<4> // CHECK: func private @box1() -> !fir.boxchar<2> @@ -98,12 +97,12 @@ %23 = fir.extract_value %22, %21 : (!fir.type, !fir.field) -> f32 // CHECK: [[VAL_26:%.*]] = constant 1 : i32 -// CHECK: [[VAL_27:%.*]] = fir.gendims [[VAL_26]], [[VAL_21]], [[VAL_26]] : (i32, i32, i32) -> !fir.dims<1> +// CHECK: [[VAL_27:%.*]] = fir.shape [[VAL_21]] : (i32) -> !fir.shape<1> // CHECK: [[VAL_28:%.*]] = constant 1.0 // CHECK: [[VAL_29:%.*]] = fir.insert_value [[VAL_24]], [[VAL_28]], [[VAL_23]] : (!fir.type, f32, !fir.field) -> !fir.type // CHECK: [[VAL_30:%.*]] = fir.len_param_index f, !fir.type %c1 = constant 1 : i32 - %24 = fir.gendims %c1, %19, %c1 : (i32, 
i32, i32) -> !fir.dims<1> + %24 = fir.shape %19 : (i32) -> !fir.shape<1> %cf1 = constant 1.0 : f32 %25 = fir.insert_value %22, %cf1, %21 : (!fir.type, f32, !fir.field) -> !fir.type %26 = fir.len_param_index f, !fir.type @@ -143,7 +142,7 @@ // CHECK: [[VAL_40:%.*]] = fir.alloca !fir.char<1> // CHECK: [[VAL_41:%.*]] = fir.alloca tuple // CHECK: [[VAL_42:%.*]] = fir.embox [[VAL_38]] : (!fir.ref) -> !fir.box -// CHECK: [[VAL_43:%.*]]:6 = fir.unbox [[VAL_42]] : (!fir.box) -> (!fir.ref, i32, i32, !fir.tdesc, i32, !fir.dims<0>) +// CHECK: [[VAL_43:%.*]]:6 = fir.unbox [[VAL_42]] : (!fir.box) -> (!fir.ref, i32, i32, !fir.tdesc, i32, !fir.array<3x?xindex>) // CHECK: [[VAL_44:%.*]] = constant 8 : i32 // CHECK: [[VAL_45:%.*]] = fir.undefined !fir.char<1> // CHECK: [[VAL_46:%.*]] = fir.emboxchar [[VAL_40]], [[VAL_44]] : (!fir.ref>, i32) -> !fir.boxchar<1> @@ -169,7 +168,7 @@ %d3 = fir.alloca !fir.char<1> %e6 = fir.alloca tuple %1 = fir.embox %0 : (!fir.ref) -> !fir.box - %2:6 = fir.unbox %1 : (!fir.box) -> (!fir.ref,i32,i32,!fir.tdesc,i32,!fir.dims<0>) + %2:6 = fir.unbox %1 : (!fir.box) -> (!fir.ref,i32,i32,!fir.tdesc,i32,!fir.array<3x?xindex>) %c8 = constant 8 : i32 %3 = fir.undefined !fir.char<1> %4 = fir.emboxchar %d3, %c8 : (!fir.ref>, i32) -> !fir.boxchar<1> @@ -572,12 +571,12 @@ return %5 : !fir.complex<16> } -// CHECK-LABEL: func @character_literal() -> !fir.array<13x!fir.char<1>> { -func @character_literal() -> !fir.array<13 x !fir.char<1>> { -// CHECK: [[VAL_185:%.*]] = fir.string_lit "Hello, World!"(13) : !fir.char<1> +// CHECK-LABEL: func @character_literal() -> !fir.char<1,13> { +func @character_literal() -> !fir.char<1,13> { +// CHECK: [[VAL_185:%.*]] = fir.string_lit "Hello, World!"(13) : !fir.char<1,13> %0 = fir.string_lit "Hello, World!"(13) : !fir.char<1> -// CHECK: return [[VAL_185]] : !fir.array<13x!fir.char<1>> - return %0 : !fir.array<13 x !fir.char<1>> +// CHECK: return [[VAL_185]] : !fir.char<1,13> + return %0 : !fir.char<1,13> // CHECK: } } @@ -592,7 
+591,7 @@ %c1 = constant 1 : index %c100 = constant 100 : index -// CHECK: [[VAL_190:%.*]], [[VAL_191:%.*]] = fir.iterate_while ([[VAL_192:%.*]] = [[VAL_188]] to [[VAL_189]] step [[VAL_188]]) and ([[VAL_193:%.*]] = [[VAL_186]]) iter_args([[VAL_194:%.*]] = [[VAL_187]]) -> (i32) { +// CHECK: %[[VAL_190:.*]]:2 = fir.iterate_while ([[VAL_192:%.*]] = [[VAL_188]] to [[VAL_189]] step [[VAL_188]]) and ([[VAL_193:%.*]] = [[VAL_186]]) iter_args([[VAL_194:%.*]] = [[VAL_187]]) -> (i32) { // CHECK: [[VAL_195:%.*]] = call @earlyexit2([[VAL_194]]) : (i32) -> i1 // CHECK: fir.result [[VAL_195]], [[VAL_194]] : i1, i32 // CHECK: } @@ -600,7 +599,38 @@ %stop = call @earlyexit2(%v) : (i32) -> i1 fir.result %stop, %v : i1, i32 } -// CHECK: return [[VAL_190]] : i1 +// CHECK: return %[[VAL_190]]#0 : i1 // CHECK: } return %newOk#0 : i1 } + +// CHECK-LABEL: @array_access +func @array_access(%arr : !fir.ref>) { + // CHECK-DAG: %[[c1:.*]] = constant 100 + // CHECK-DAG: %[[c2:.*]] = constant 50 + %c100 = constant 100 : index + %c50 = constant 50 : index + // CHECK: %[[sh:.*]] = fir.shape %[[c1]], %[[c2]] : {{.*}} -> !fir.shape<2> + %shape = fir.shape %c100, %c50 : (index, index) -> !fir.shape<2> + %c47 = constant 47 : index + %c78 = constant 78 : index + %c3 = constant 3 : index + %c18 = constant 18 : index + %c36 = constant 36 : index + %c4 = constant 4 : index + // CHECK: %[[sl:.*]] = fir.slice {{.*}} -> !fir.slice<2> + %slice = fir.slice %c47, %c78, %c3, %c18, %c36, %c4 : (index,index,index,index,index,index) -> !fir.slice<2> + %c0 = constant 0 : index + %c99 = constant 99 : index + %c1 = constant 1 : index + fir.do_loop %i = %c0 to %c99 step %c1 { + %c49 = constant 49 : index + fir.do_loop %j = %c0 to %c49 step %c1 { + // CHECK: fir.array_coor %{{.*}}(%[[sh]]) [%[[sl]]] %{{.*}}, %{{.*}} : + %p = fir.array_coor %arr(%shape)[%slice] %i, %j : (!fir.ref>, !fir.shape<2>, !fir.slice<2>, index, index) -> !fir.ref + %x = constant 42.0 : f32 + fir.store %x to %p : !fir.ref + } + } + return +} 
Index: flang/test/Fir/fir-types.fir =================================================================== --- flang/test/Fir/fir-types.fir +++ flang/test/Fir/fir-types.fir @@ -1,7 +1,6 @@ // Test the FIR types // RUN: tco -emit-fir %s | tco -emit-fir | FileCheck %s -// UNSUPPORTED: !fir // Fortran Intrinsic types // CHECK-LABEL: func private @it1() -> !fir.int<4> @@ -9,11 +8,15 @@ // CHECK-LABEL: func private @it3() -> !fir.complex<8> // CHECK-LABEL: func private @it4() -> !fir.logical<1> // CHECK-LABEL: func private @it5() -> !fir.char<1> +// CHECK-LABEL: func private @it6() -> !fir.char<2,10> +// CHECK-LABEL: func private @it7() -> !fir.char<4,?> func private @it1() -> !fir.int<4> func private @it2() -> !fir.real<8> func private @it3() -> !fir.complex<8> func private @it4() -> !fir.logical<1> func private @it5() -> !fir.char<1> +func private @it6() -> !fir.char<2,10> +func private @it7() -> !fir.char<4,?> // Fortran Derived types (records) // CHECK-LABEL: func private @dvd1() -> !fir.type @@ -68,11 +71,13 @@ func private @box5() -> !fir.box> // FIR misc. types -// CHECK-LABEL: func private @oth1() -> !fir.dims<1> +// CHECK-LABEL: func private @oth1() -> !fir.shape<1> // CHECK-LABEL: func private @oth2() -> !fir.field // CHECK-LABEL: func private @oth3() -> !fir.tdesc> -// CHECK-LABEL: func private @oth4() -> !fir.dims<15> -func private @oth1() -> !fir.dims<1> +// CHECK-LABEL: func private @oth4() -> !fir.shapeshift<15> +// CHECK-LABEL: func private @oth5() -> !fir.slice<8> +func private @oth1() -> !fir.shape<1> func private @oth2() -> !fir.field func private @oth3() -> !fir.tdesc> -func private @oth4() -> !fir.dims<15> +func private @oth4() -> !fir.shapeshift<15> +func private @oth5() -> !fir.slice<8>