diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.h b/flang/include/flang/Optimizer/Dialect/FIROps.h --- a/flang/include/flang/Optimizer/Dialect/FIROps.h +++ b/flang/include/flang/Optimizer/Dialect/FIROps.h @@ -9,7 +9,9 @@ #ifndef FORTRAN_OPTIMIZER_DIALECT_FIROPS_H #define FORTRAN_OPTIMIZER_DIALECT_FIROPS_H +#include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/FortranVariableInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -16,6 +16,9 @@ include "flang/Optimizer/Dialect/FIRDialect.td" include "flang/Optimizer/Dialect/FIRTypes.td" +include "flang/Optimizer/Dialect/FIRAttr.td" +include "flang/Optimizer/Dialect/FortranVariableInterface.td" +include "mlir/IR/BuiltinAttributes.td" // Base class for FIR operations. // All operations automatically get a prefix of "fir.". @@ -2863,4 +2866,61 @@ let results = (outs BoolLike); } +def fir_DeclareOp : fir_Op<"declare", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "declare a variable"; + + let description = [{ + Tie the properties of a Fortran variable to an address. The properties + include bounds, length parameters, and Fortran attributes. + + The memref argument describes the storage of the variable. It may be a + raw address (fir.ref), or a box or class value or address (fir.box, + fir.ref>, fir.class, fir.ref>). + + The shape argument encodes explicit extents and lower bounds. It must be + provided if the memref is the raw address of an array.
+ The shape argument must not be provided if the memref operand is a box or + class value or address, unless the shape is a shift (encodes lower bounds) + and the memref is a box value (this covers assumed shapes with local lower + bounds). + + The typeparams values are meant to carry the non-deferred length parameters + (this includes both Fortran assumed and explicit length parameters). + It must always be provided for characters and parametrized derived types + when memref is not a box value or address. + + Example: + + CHARACTER(n), OPTIONAL, TARGET :: c(10:, 20:) + + Can be represented as: + ``` + func.func @foo(%arg0: !fir.box>>, %arg1: !fir.ref) { + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + %1 = fir.load %arg1 : !fir.ref + %2 = fir.shift %c10, %c20 : (index, index) -> !fir.shift<2> + %3 = fir.declare %arg0(%2) typeparams %1 {fortran_attrs = #fir.var_attrs, uniq_name = "c"} + // ... uses %3 as "c" + } + ``` + }]; + + let arguments = (ins + AnyRefOrBox:$memref, + Optional:$shape, + Variadic:$typeparams, + Builtin_StringAttr:$uniq_name, + OptionalAttr:$fortran_attrs + ); + + let results = (outs AnyRefOrBox); + + let assemblyFormat = [{ + $memref (`(` $shape^ `)`)? (`typeparams` $typeparams^)? + attr-dict `:` functional-type(operands, results) + }]; +} + #endif diff --git a/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td b/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td --- a/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td +++ b/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td @@ -46,8 +46,8 @@ }] >, InterfaceMethod< - /*desc=*/"Get the shape of the variable", - /*retTy=*/"llvm::Optional", + /*desc=*/"Get the shape of the variable. May be a null value.", + /*retTy=*/"mlir::Value", /*methodName=*/"getShape", /*args=*/(ins), /*methodBody=*/[{}], @@ -74,7 +74,7 @@ /// Get the sequence type or scalar value type corresponding to this /// variable.
mlir::Type getElementOrSequenceType() { - return fir::unwrapPassByRefType(getBase().getType()); + return fir::unwrapPassByRefType(fir::unwrapRefType(getBase().getType())); } /// Get the scalar value type corresponding to this variable. @@ -87,6 +87,17 @@ return getElementOrSequenceType().isa(); } + /// Return the rank of the entity if it is known at compile time (0 for a scalar). + llvm::Optional getRank() { + if (auto sequenceType = + getElementOrSequenceType().dyn_cast()) { + if (sequenceType.hasUnknownShape()) + return {}; + return sequenceType.getDimension(); + } + return 0; + } + /// Is this variable a Fortran pointer ? bool isPointer() { auto attrs = getFortranAttrs(); @@ -117,8 +128,32 @@ return getExplicitTypeParams()[0]; } + /// Is this variable represented as a fir.box or fir.class value ? + bool isBoxValue() { + return getBase().getType().isa(); + } + + /// Is this variable represented as a fir.box or fir.class address ? + bool isBoxAddress() { + mlir::Type type = getBase().getType(); + return fir::isa_ref_type(type) && + fir::unwrapRefType(type).isa(); + } + + /// Is this variable represented as the value or address of a fir.box or + /// fir.class ? + bool isBox() { + return fir::unwrapRefType(getBase().getType()).isa(); + } + + /// Interface verifier implementation, called from the interface verify hook.
+ mlir::LogicalResult verifyImpl(); + }]; + let verify = [{ + return ::mlir::cast<::fir::FortranVariableOpInterface>($_op).verifyImpl(); + }]; } #endif // FORTRANVARIABLEINTERFACE diff --git a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp --- a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp +++ b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp @@ -15,3 +15,58 @@ namespace fir { #include "flang/Optimizer/Dialect/FortranVariableInterface.cpp.inc" } + +/// Verify the invariants shared by operations implementing +/// FortranVariableOpInterface: the explicit type parameters must agree with +/// the element type, and the shape operand, if any, must agree with the base +/// representation (raw address, box value, or box address) and its rank. +mlir::LogicalResult fir::FortranVariableOpInterface::verifyImpl() { + const unsigned numExplicitTypeParams = getExplicitTypeParams().size(); + if (isCharacter()) { + if (numExplicitTypeParams > 1) + return emitOpError( + "of character entity must have at most one length parameter"); + if (numExplicitTypeParams == 0 && !isBox()) + return emitOpError("must be provided exactly one type parameter when its " + "base is a character that is not a box"); + + } else if (auto recordType = getElementType().dyn_cast()) { + if (numExplicitTypeParams < recordType.getNumLenParams() && !isBox()) + return emitOpError("must be provided all the derived type length " + "parameters when the base is not a box"); + if (numExplicitTypeParams > recordType.getNumLenParams()) + return emitOpError("has too many length parameters"); + } else if (numExplicitTypeParams != 0) { + return emitOpError("of numeric, logical, or assumed type entity must not " + "have length parameters"); + } + + if (isArray()) { + if (mlir::Value shape = getShape()) { + if (isBoxAddress()) + return emitOpError("for box address must not have a shape operand"); + unsigned shapeRank = 0; + if (auto shapeType = shape.getType().dyn_cast()) { + shapeRank = shapeType.getRank(); + } else if (auto shapeShiftType = + shape.getType().dyn_cast()) { + shapeRank = shapeShiftType.getRank(); + } else { + // Otherwise, the shape is a shift (lower bounds only): the extents can + // only be provided by a box value base. + if (!isBoxValue())
+ return emitOpError("of array entity with a raw address base must " + "have a shape operand that is a shape or shapeshift"); + shapeRank = shape.getType().cast().getRank(); + } + + llvm::Optional rank = getRank(); + if (!rank || *rank != shapeRank) + return emitOpError("has conflicting shape and base operand ranks"); + } else if (!isBox()) { + return emitOpError("of array entity with a raw address base must have a " + "shape operand that is a shape or shapeshift"); + } + } + return mlir::success(); +} diff --git a/flang/test/Fir/declare.fir b/flang/test/Fir/declare.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/declare.fir @@ -0,0 +1,145 @@ +// Test fir.declare operation parse, verify (no errors), and unparse. + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @numeric_declare(%arg0: !fir.ref) { + %0 = fir.declare %arg0 {uniq_name = "x"} : (!fir.ref) -> !fir.ref + return +} +// CHECK-LABEL: func.func @numeric_declare( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref) { +// CHECK: %[[VAL_1:.*]] = fir.declare %[[VAL_0]] {uniq_name = "x"} : (!fir.ref) -> !fir.ref + + +func.func @char_declare(%arg0: !fir.boxchar<1> ) { + %0:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref>, index) + %1 = fir.declare %0#0 typeparams %0#1 {uniq_name = "c"} : (!fir.ref>, index) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @char_declare( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.boxchar<1>) { +// CHECK: %[[VAL_1:.*]]:2 = fir.unboxchar %[[VAL_0]] : (!fir.boxchar<1>) -> (!fir.ref>, index) +// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_1]]#0 typeparams %[[VAL_1]]#1 {uniq_name = "c"} : (!fir.ref>, index) -> !fir.ref> + + +func.func @derived_declare(%arg0: !fir.ref>) { + %0 = fir.declare %arg0 {uniq_name = "x"} : (!fir.ref>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @derived_declare( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: %[[VAL_1:.*]] = fir.declare %[[VAL_0]] {uniq_name = "x"} : (!fir.ref>) -> !fir.ref> + + +func.func @pdt_declare(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + %0 = fir.declare %arg0 typeparams %c1 {uniq_name =
"x"} : (!fir.ref>, index) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @pdt_declare( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_0]] typeparams %[[VAL_1]] {uniq_name = "x"} : (!fir.ref>, index) -> !fir.ref> + + +func.func @array_declare(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %shape = fir.shape %c1, %c2 : (index, index) -> !fir.shape<2> + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>, !fir.shape<2>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @array_declare( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shape<2> +// CHECK: %[[VAL_4:.*]] = fir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "x"} : (!fir.ref>, !fir.shape<2>) -> !fir.ref> + + +func.func @array_declare_2(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %shape = fir.shape_shift %c1, %c2, %c3, %c4 : (index, index, index, index) -> !fir.shapeshift<2> + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>, !fir.shapeshift<2>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @array_declare_2( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_5:.*]] = fir.shape_shift %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (index, index, index, index) -> !fir.shapeshift<2> +// CHECK: %[[VAL_6:.*]] = fir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "x"} : (!fir.ref>, !fir.shapeshift<2>) -> !fir.ref> + + +func.func
@array_declare_box(%arg0: !fir.box>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %shape = fir.shift %c1, %c2 : (index, index) -> !fir.shift<2> + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.box>, !fir.shift<2>) -> !fir.box> + return +} +// CHECK-LABEL: func.func @array_declare_box( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = fir.shift %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shift<2> +// CHECK: %[[VAL_4:.*]] = fir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "x"} : (!fir.box>, !fir.shift<2>) -> !fir.box> + + +func.func @array_declare_char_box(%arg0: !fir.box>>) { + %0 = fir.declare %arg0 {uniq_name = "x"} : (!fir.box>>) -> !fir.box>> + return +} +// CHECK-LABEL: func.func @array_declare_char_box( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>>) { +// CHECK: %[[VAL_1:.*]] = fir.declare %[[VAL_0]] {uniq_name = "x"} : (!fir.box>>) -> !fir.box>> + + +func.func @array_declare_char_box_2(%arg0: !fir.box>>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %shape = fir.shift %c1, %c2 : (index, index) -> !fir.shift<2> + %0 = fir.declare %arg0(%shape) typeparams %c3 {uniq_name = "x"} : (!fir.box>>, !fir.shift<2>, index) -> !fir.box>> + return +} +// CHECK-LABEL: func.func @array_declare_char_box_2( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_4:.*]] = fir.shift %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shift<2> +// CHECK: %[[VAL_5:.*]] = fir.declare %[[VAL_0]](%[[VAL_4]]) typeparams %[[VAL_3]] {uniq_name = "x"} : (!fir.box>>, !fir.shift<2>, index) -> !fir.box>> + + +func.func @array_declare_char_boxaddr(%arg0: !fir.ref>>>>) { + %0 = fir.declare %arg0 {uniq_name = "x"} :
(!fir.ref>>>>) -> !fir.ref>>>> + return +} +// CHECK-LABEL: func.func @array_declare_char_boxaddr( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>>>) { +// CHECK: %[[VAL_1:.*]] = fir.declare %[[VAL_0]] {uniq_name = "x"} : (!fir.ref>>>>) -> !fir.ref>>>> + + +func.func @array_declare_char_boxaddr_2(%arg0: !fir.ref>>>>) { + %c3 = arith.constant 3 : index + %0 = fir.declare %arg0 typeparams %c3 {uniq_name = "x"} : (!fir.ref>>>>, index) -> !fir.ref>>>> + return +} +// CHECK-LABEL: func.func @array_declare_char_boxaddr_2( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_2:.*]] = fir.declare %[[VAL_0]] typeparams %[[VAL_1]] {uniq_name = "x"} : (!fir.ref>>>>, index) -> !fir.ref>>>> + +func.func @array_declare_unlimited_polymorphic_boxaddr(%arg0: !fir.ref>>>) { + %0 = fir.declare %arg0 {uniq_name = "x"} : (!fir.ref>>>) -> !fir.ref>>> + return +} +// CHECK-LABEL: func.func @array_declare_unlimited_polymorphic_boxaddr( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>>) { +// CHECK: %[[VAL_1:.*]] = fir.declare %[[VAL_0]] {uniq_name = "x"} : (!fir.ref>>>) -> !fir.ref>>> diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -808,3 +808,123 @@ // expected-error@+1 {{Unknown fortran variable attribute: volatypo}} %0 = fir.alloca f32 {fortran_attrs = #fir.var_attrs} } + +// ----- +func.func @bad_numeric_declare(%arg0: !fir.ref) { + %c1 = arith.constant 1 : index + // expected-error@+1 {{'fir.declare' op requires attribute 'uniq_name'}} + %0 = fir.declare %arg0 typeparams %c1 {uniq_typo = "x"} : (!fir.ref, index) -> !fir.ref + return +} + +// ----- +func.func @bad_numeric_declare(%arg0: !fir.ref) { + %c1 = arith.constant 1 : index + // expected-error@+1 {{'fir.declare' op of numeric, logical, or assumed type entity must not have length parameters}} + %0 = fir.declare %arg0 typeparams %c1 {uniq_name = "x"} : (!fir.ref, index) -> !fir.ref + return
+} + +// ----- +func.func @bad_char_declare(%arg0: !fir.boxchar<1> ) { + %0:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref>, index) + // expected-error@+1 {{'fir.declare' op must be provided exactly one type parameter when its base is a character that is not a box}} + %1 = fir.declare %0#0 {uniq_name = "c"} : (!fir.ref>) -> !fir.ref> + return +} + +// ----- +func.func @bad_char_declare(%arg0: !fir.boxchar<1> ) { + %0:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref>, index) + // expected-error@+1 {{'fir.declare' op of character entity must have at most one length parameter}} + %1 = fir.declare %0#0 typeparams %0#1, %0#1 {uniq_name = "c"} : (!fir.ref>, index, index) -> !fir.ref> + return +} + +// ----- +func.func @bad_derived_declare(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + // expected-error@+1 {{'fir.declare' op has too many length parameters}} + %0 = fir.declare %arg0 typeparams %c1 {uniq_name = "x"} : (!fir.ref>, index) -> !fir.ref> + return +} + +// ----- +func.func @bad_pdt_declare(%arg0: !fir.ref>) { + // expected-error@+1 {{'fir.declare' op must be provided all the derived type length parameters when the base is not a box}} + %0 = fir.declare %arg0 {uniq_name = "x"} : (!fir.ref>) -> !fir.ref> + return +} + +// ----- +func.func @bad_pdt_declare_2(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + // expected-error@+1 {{'fir.declare' op has too many length parameters}} + %0 = fir.declare %arg0 typeparams %c1, %c1 {uniq_name = "x"} : (!fir.ref>, index, index) -> !fir.ref> + return +} + + +// ----- +func.func @bad_array_declare(%arg0: !fir.ref>) { + // expected-error@+1 {{'fir.declare' op of array entity with a raw address base must have a shape operand that is a shape or shapeshift}} + %0 = fir.declare %arg0 {uniq_name = "x"} : (!fir.ref>) -> !fir.ref> + return +} + +// ----- +func.func @bad_array_declare_2(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %shift = fir.shift %c1, %c2 :
(index, index) -> !fir.shift<2> + // expected-error@+1 {{'fir.declare' op of array entity with a raw address base must have a shape operand that is a shape or shapeshift}} + %0 = fir.declare %arg0(%shift) {uniq_name = "x"} : (!fir.ref>, !fir.shift<2>) -> !fir.ref> + return +} + +// ----- +func.func @bad_array_declare_3(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + %shape = fir.shape %c1 : (index) -> !fir.shape<1> + // expected-error@+1 {{'fir.declare' op has conflicting shape and base operand ranks}} + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + return +} + +// ----- +func.func @bad_array_declare_4(%arg0: !fir.ref>) { + %c1 = arith.constant 1 : index + %shape = fir.shape_shift %c1, %c1 : (index, index) -> !fir.shapeshift<1> + // expected-error@+1 {{'fir.declare' op has conflicting shape and base operand ranks}} + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>, !fir.shapeshift<1>) -> !fir.ref> + return +} + +// ----- +func.func @bad_array_declare_box(%arg0: !fir.box>) { + %c1 = arith.constant 1 : index + %shape = fir.shift %c1 : (index) -> !fir.shift<1> + // expected-error@+1 {{'fir.declare' op has conflicting shape and base operand ranks}} + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.box>, !fir.shift<1>) -> !fir.box> + return +} + +// ----- +func.func @bad_array_declare_char_boxaddr(%arg0: !fir.ref>>>>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %shape = fir.shift %c1, %c2 : (index, index) -> !fir.shift<2> + // expected-error@+1 {{'fir.declare' op for box address must not have a shape operand}} + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>>>>, !fir.shift<2>) -> !fir.ref>>>> + return +} + +// ----- +func.func @bad_array_declare_unlimited_polymorphic_boxaddr(%arg0: !fir.ref>>>) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %shape = fir.shift %c1, %c2 : (index, index) -> !fir.shift<2> + // expected-error@+1
{{'fir.declare' op for box address must not have a shape operand}} + %0 = fir.declare %arg0(%shape) {uniq_name = "x"} : (!fir.ref>>>, !fir.shift<2>) -> !fir.ref>>> + return +} diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt --- a/flang/unittests/Optimizer/CMakeLists.txt +++ b/flang/unittests/Optimizer/CMakeLists.txt @@ -24,6 +24,7 @@ Builder/Runtime/TransformationalTest.cpp FIRContextTest.cpp FIRTypesTest.cpp + FortranVariableTest.cpp InternalNamesTest.cpp KindMappingTest.cpp RTBuilder.cpp diff --git a/flang/unittests/Optimizer/FortranVariableTest.cpp b/flang/unittests/Optimizer/FortranVariableTest.cpp new file mode 100644 --- /dev/null +++ b/flang/unittests/Optimizer/FortranVariableTest.cpp @@ -0,0 +1,151 @@ +//===- FortranVariableTest.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gtest/gtest.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Support/InitFIR.h" + +struct FortranVariableTest : public testing::Test { +public: + void SetUp() override { + fir::support::loadDialects(context); + builder = std::make_unique(&context); + mlir::Location loc = builder->getUnknownLoc(); + + // Set up a Module with a dummy function operation inside. + // Set the insertion point in the function entry block.
+ mlir::ModuleOp mod = builder->create(loc); + mlir::func::FuncOp func = + mlir::func::FuncOp::create(loc, "fortran_variable_tests", + builder->getFunctionType(llvm::None, llvm::None)); + auto *entryBlock = func.addEntryBlock(); + mod.push_back(func); + builder->setInsertionPointToStart(entryBlock); + } + + mlir::Location getLoc() { return builder->getUnknownLoc(); } + mlir::Value createConstant(std::int64_t cst) { + mlir::Type indexType = builder->getIndexType(); + return builder->create( + getLoc(), indexType, builder->getIntegerAttr(indexType, cst)); + } + + mlir::Value createShape(llvm::ArrayRef extents) { + mlir::Type shapeType = fir::ShapeType::get(&context, extents.size()); + return builder->create(getLoc(), shapeType, extents); + } + mlir::MLIRContext context; + std::unique_ptr builder; +}; + +TEST_F(FortranVariableTest, SimpleScalar) { + mlir::Location loc = getLoc(); + mlir::Type eleType = mlir::FloatType::getF32(&context); + mlir::Value addr = builder->create(loc, eleType); + auto name = mlir::StringAttr::get(&context, "x"); + auto declare = builder->create(loc, addr.getType(), addr, + /*shape=*/mlir::Value{}, /*typeParams=*/llvm::None, name, + /*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); + + fir::FortranVariableOpInterface fortranVariable = declare; + EXPECT_FALSE(fortranVariable.isArray()); + EXPECT_FALSE(fortranVariable.isCharacter()); + EXPECT_FALSE(fortranVariable.isPointer()); + EXPECT_FALSE(fortranVariable.isAllocatable()); + EXPECT_FALSE(fortranVariable.hasExplicitCharLen()); + EXPECT_EQ(fortranVariable.getElementType(), eleType); + EXPECT_EQ(fortranVariable.getElementOrSequenceType(), + fortranVariable.getElementType()); + EXPECT_NE(fortranVariable.getBase(), addr); + EXPECT_EQ(fortranVariable.getBase().getType(), addr.getType()); +} + +TEST_F(FortranVariableTest, CharacterScalar) { + mlir::Location loc = getLoc(); + mlir::Type eleType = fir::CharacterType::getUnknownLen(&context, 4); + mlir::Value len = createConstant(42);
+ llvm::SmallVector typeParams{len}; + mlir::Value addr = builder->create( + loc, eleType, /*pinned=*/false, typeParams); + auto name = mlir::StringAttr::get(&context, "x"); + auto declare = builder->create(loc, addr.getType(), addr, + /*shape=*/mlir::Value{}, typeParams, name, + /*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); + + fir::FortranVariableOpInterface fortranVariable = declare; + EXPECT_FALSE(fortranVariable.isArray()); + EXPECT_TRUE(fortranVariable.isCharacter()); + EXPECT_FALSE(fortranVariable.isPointer()); + EXPECT_FALSE(fortranVariable.isAllocatable()); + EXPECT_TRUE(fortranVariable.hasExplicitCharLen()); + EXPECT_EQ(fortranVariable.getElementType(), eleType); + EXPECT_EQ(fortranVariable.getElementOrSequenceType(), + fortranVariable.getElementType()); + EXPECT_NE(fortranVariable.getBase(), addr); + EXPECT_EQ(fortranVariable.getBase().getType(), addr.getType()); + EXPECT_EQ(fortranVariable.getExplicitCharLen(), len); +} + +TEST_F(FortranVariableTest, SimpleArray) { + mlir::Location loc = getLoc(); + mlir::Type eleType = mlir::FloatType::getF32(&context); + llvm::SmallVector extents{ + createConstant(10), createConstant(20), createConstant(30)}; + fir::SequenceType::Shape typeShape( + extents.size(), fir::SequenceType::getUnknownExtent()); + mlir::Type seqTy = fir::SequenceType::get(typeShape, eleType); + mlir::Value addr = builder->create( + loc, seqTy, /*pinned=*/false, /*typeParams=*/llvm::None, extents); + mlir::Value shape = createShape(extents); + auto name = mlir::StringAttr::get(&context, "x"); + auto declare = builder->create(loc, addr.getType(), addr, + shape, /*typeParams*/ llvm::None, name, + /*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); + + fir::FortranVariableOpInterface fortranVariable = declare; + EXPECT_TRUE(fortranVariable.isArray()); + EXPECT_FALSE(fortranVariable.isCharacter()); + EXPECT_FALSE(fortranVariable.isPointer()); + EXPECT_FALSE(fortranVariable.isAllocatable()); + EXPECT_FALSE(fortranVariable.hasExplicitCharLen());
+ EXPECT_EQ(fortranVariable.getElementType(), eleType); + EXPECT_EQ(fortranVariable.getElementOrSequenceType(), seqTy); + EXPECT_NE(fortranVariable.getBase(), addr); + EXPECT_EQ(fortranVariable.getBase().getType(), addr.getType()); +} + +TEST_F(FortranVariableTest, CharacterArray) { + mlir::Location loc = getLoc(); + mlir::Type eleType = fir::CharacterType::getUnknownLen(&context, 4); + mlir::Value len = createConstant(42); + llvm::SmallVector typeParams{len}; + llvm::SmallVector extents{ + createConstant(10), createConstant(20), createConstant(30)}; + fir::SequenceType::Shape typeShape( + extents.size(), fir::SequenceType::getUnknownExtent()); + mlir::Type seqTy = fir::SequenceType::get(typeShape, eleType); + mlir::Value addr = builder->create( + loc, seqTy, /*pinned=*/false, typeParams, extents); + mlir::Value shape = createShape(extents); + auto name = mlir::StringAttr::get(&context, "x"); + auto declare = builder->create(loc, addr.getType(), addr, + shape, typeParams, name, + /*fortran_attrs=*/fir::FortranVariableFlagsAttr{}); + + fir::FortranVariableOpInterface fortranVariable = declare; + EXPECT_TRUE(fortranVariable.isArray()); + EXPECT_TRUE(fortranVariable.isCharacter()); + EXPECT_FALSE(fortranVariable.isPointer()); + EXPECT_FALSE(fortranVariable.isAllocatable()); + EXPECT_TRUE(fortranVariable.hasExplicitCharLen()); + EXPECT_EQ(fortranVariable.getElementType(), eleType); + EXPECT_EQ(fortranVariable.getElementOrSequenceType(), seqTy); + EXPECT_NE(fortranVariable.getBase(), addr); + EXPECT_EQ(fortranVariable.getBase().getType(), addr.getType()); + EXPECT_EQ(fortranVariable.getExplicitCharLen(), len); +}