diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -251,6 +251,51 @@ let hasVerifier = 1; } +def hlfir_ParentComponentOp : hlfir_Op<"parent_comp", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "Designate the parent component of a variable"; + + let description = [{ + This operation represents a Fortran component reference where the + component name is the name of a parent type of the variable's derived type. + These component references cannot be represented with an hlfir.designate + because, unlike the actual component names, the parent type names are not + embedded in fir.type<> types. + + The operands are as follows: + - memref is a derived type variable whose parent component is being + designated. + - shape is the shape of memref and the result, and must be provided if and + only if memref is an array. Parent component reference lower bounds are + ones, so the provided shape must be a fir.shape. + - typeparams are the type parameters of the parent component type if any. + They are a subset of the memref type parameters. + The parent component type and name are reflected in the result type. + }]; + + let arguments = (ins AnyFortranVariable:$memref, + Optional:$shape, + Variadic:$typeparams); + + let extraClassDeclaration = [{ + // Implement the FortranVariableInterface. Parent components have + // no attributes (pointer, allocatable or contiguous can only be added + // to regular components). + std::optional getFortranAttrs() const { + return std::nullopt; + } + }]; + + let results = (outs AnyFortranVariable); + + let assemblyFormat = [{ + $memref (`shape` $shape^)? (`typeparams` $typeparams^)? 
+ attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + def hlfir_ConcatOp : hlfir_Op<"concat", []> { let summary = "concatenate characters"; let description = [{ diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -380,6 +380,61 @@ return mlir::success(); } +//===----------------------------------------------------------------------===// +// ParentComponentOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult hlfir::ParentComponentOp::verify() { + mlir::Type baseType = + hlfir::getFortranElementOrSequenceType(getMemref().getType()); + auto maybeInputSeqType = baseType.dyn_cast(); + unsigned inputTypeRank = + maybeInputSeqType ? maybeInputSeqType.getDimension() : 0; + unsigned shapeRank = 0; + if (mlir::Value shape = getShape()) + if (auto shapeType = shape.getType().dyn_cast()) + shapeRank = shapeType.getRank(); + if (inputTypeRank != shapeRank) + return emitOpError( + "must be provided a shape if and only if the base is an array"); + mlir::Type outputBaseType = hlfir::getFortranElementOrSequenceType(getType()); + auto maybeOutputSeqType = outputBaseType.dyn_cast(); + unsigned outputTypeRank = + maybeOutputSeqType ? 
maybeOutputSeqType.getDimension() : 0; + if (inputTypeRank != outputTypeRank) + return emitOpError("result type rank must match input type rank"); + if (maybeOutputSeqType && maybeInputSeqType) + for (auto [inputDim, outputDim] : + llvm::zip(maybeInputSeqType.getShape(), maybeOutputSeqType.getShape())) + if (inputDim != fir::SequenceType::getUnknownExtent() && + outputDim != fir::SequenceType::getUnknownExtent()) + if (inputDim != outputDim) + return emitOpError( + "result type extents are inconsistent with memref type"); + fir::RecordType baseRecType = + hlfir::getFortranElementType(baseType).dyn_cast(); + fir::RecordType outRecType = + hlfir::getFortranElementType(outputBaseType).dyn_cast(); + if (!baseRecType || !outRecType) + return emitOpError("result type and input type must be derived types"); + + // Note: the result must not be a fir.class: its dynamic type is the parent + // type, and allowing fir.class would break the operation codegen, which + // would keep the input dynamic type. + if (getType().isa()) + return emitOpError("result type must not be polymorphic"); + + // Array results are known to be non-contiguous in most cases (the + // exception being if the parent type was extended by a type without any + // components): require a fir.box to be used for the result to carry the + // strides. 
+ if (!getType().isa() && + (outputTypeRank != 0 || fir::isRecordWithTypeParameters(outRecType))) + return emitOpError("result type must be a fir.box if the result is an " + "array or has length parameters"); + return mlir::success(); +} + //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -506,6 +506,54 @@ } }; +class ParentComponentOpConversion + : public mlir::OpRewritePattern { +public: + explicit ParentComponentOpConversion(mlir::MLIRContext *ctx) + : OpRewritePattern{ctx} {} + + mlir::LogicalResult + matchAndRewrite(hlfir::ParentComponentOp parentComponent, + mlir::PatternRewriter &rewriter) const override { + mlir::Location loc = parentComponent.getLoc(); + mlir::Type resultType = parentComponent.getType(); + if (!parentComponent.getType().isa()) { + mlir::Value baseAddr = parentComponent.getMemref(); + // Scalar parent component ref without any length type parameters. The + // input may be a fir.class if it is polymorphic. Since this is a scalar + // and the output is monomorphic, the base address can be extracted from + // the box and converted to the parent component reference type. + if (baseAddr.getType().isa()) + baseAddr = rewriter.create(loc, baseAddr); + rewriter.replaceOpWithNewOp(parentComponent, resultType, + baseAddr); + return mlir::success(); + } + // Array parent component ref or PDTs: the result is a box built by rebox. + hlfir::Entity base{parentComponent.getMemref()}; + mlir::Value baseAddr = base.getBase(); + if (!baseAddr.getType().isa()) { + // Embox cannot directly be used to address parent components: it expects + // the output type to match the input type when there are no slices. 
When + // the types have at least one component, a slice to the first element can + // be built, and the result set to the parent component type. For now, box + // the base with its own type and let the rebox below set the parent type. + mlir::Type baseBoxType = + fir::BoxType::get(base.getElementOrSequenceType()); + assert(!base.hasLengthParameters() && + "base must be a box if it has any type parameters"); + baseAddr = rewriter.create( + loc, baseBoxType, baseAddr, parentComponent.getShape(), + /*slice=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{}); + } + rewriter.replaceOpWithNewOp(parentComponent, resultType, + baseAddr, + /*shape=*/mlir::Value{}, + /*slice=*/mlir::Value{}); + return mlir::success(); + } +}; + class NoReassocOpConversion : public mlir::OpRewritePattern { public: @@ -546,7 +594,8 @@ mlir::RewritePatternSet patterns(context); patterns.insert(context); + NoReassocOpConversion, NullOpConversion, + ParentComponentOpConversion>(context); mlir::ConversionTarget target(*context); target.addIllegalDialect(); target.markUnknownOpDynamicallyLegal( diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -389,3 +389,51 @@ hlfir.assign %arg1 to %arg0 realloc keep_lhs_len : !fir.box>, !fir.ref>>> return } + +// ----- +func.func @bad_parent_comp1(%arg0: !fir.box>>) { + // expected-error@+1 {{'hlfir.parent_comp' op must be provided a shape if and only if the base is an array}} + %2 = hlfir.parent_comp %arg0 : (!fir.box>>) -> !fir.box>> + return +} + +// ----- +func.func @bad_parent_comp2(%arg0: !fir.box>>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.parent_comp' op result type rank must match input type rank}} + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.box>>, !fir.shape<1>) -> !fir.box>> + return +} + +// ----- +func.func @bad_parent_comp3(%arg0: !fir.box>>) { + %c10 = arith.constant 10 : index + %1 = 
fir.shape %c10 : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.parent_comp' op result type extents are inconsistent with memref type}} + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.box>>, !fir.shape<1>) -> !fir.box>> + return +} + +// ----- +func.func @bad_parent_comp4(%arg0: !fir.ref>) { + // expected-error@+1 {{'hlfir.parent_comp' op result type and input type must be derived types}} + %1 = hlfir.parent_comp %arg0 : (!fir.ref>) -> !fir.ref + return +} + +// ----- +func.func @bad_parent_comp5(%arg0: !fir.class>) { + // expected-error@+1 {{'hlfir.parent_comp' op result type must not be polymorphic}} + %2 = hlfir.parent_comp %arg0 : (!fir.class>) -> !fir.class> + return +} + +// ----- +func.func @bad_parent_comp6(%arg0: !fir.box>>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.parent_comp' op result type must be a fir.box if the result is an array or has length parameters}} + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.box>>, !fir.shape<1>) -> !fir.ref>> + return +} diff --git a/flang/test/HLFIR/parent_comp-codegen.fir b/flang/test/HLFIR/parent_comp-codegen.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/parent_comp-codegen.fir @@ -0,0 +1,44 @@ +// Test hlfir.parent_comp code generation to FIR +// RUN: fir-opt %s -convert-hlfir-to-fir | FileCheck %s + +func.func @test_scalar(%arg0: !fir.ref>) { + %1 = hlfir.parent_comp %arg0 : (!fir.ref>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @test_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.ref> + +func.func @test_scalar_polymorphic(%arg0: !fir.class>) { + %1 = hlfir.parent_comp %arg0 : (!fir.class>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @test_scalar_polymorphic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.class>) { +// CHECK: %[[VAL_1:.*]] = fir.box_addr %[[VAL_0]] : (!fir.class>) -> !fir.ref> +// CHECK: fir.convert %[[VAL_1]] : (!fir.ref>) -> 
!fir.ref> + +func.func @test_array(%arg0: !fir.ref>>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> + return +} +// CHECK-LABEL: func.func @test_array( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_3:.*]] = fir.embox %[[VAL_0]](%[[VAL_2]]) : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> +// CHECK: fir.rebox %[[VAL_3]] : (!fir.box>>) -> !fir.box>> + +func.func @test_array_polymorphic(%arg0: !fir.class>>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.class>>, !fir.shape<1>) -> !fir.box>> + return +} +// CHECK-LABEL: func.func @test_array_polymorphic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.class>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: fir.rebox %[[VAL_0]] : (!fir.class>>) -> !fir.box>> diff --git a/flang/test/HLFIR/parent_comp.fir b/flang/test/HLFIR/parent_comp.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/parent_comp.fir @@ -0,0 +1,42 @@ +// Test hlfir.parent_comp operation parse, verify (no errors), and unparse. 
+// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @test_scalar(%arg0: !fir.ref>) { + %1 = hlfir.parent_comp %arg0 : (!fir.ref>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @test_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: hlfir.parent_comp %[[VAL_0]] : (!fir.ref>) -> !fir.ref> + +func.func @test_scalar_polymorphic(%arg0: !fir.class>) { + %1 = hlfir.parent_comp %arg0 : (!fir.class>) -> !fir.ref> + return +} +// CHECK-LABEL: func.func @test_scalar_polymorphic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.class>) { +// CHECK: hlfir.parent_comp %[[VAL_0]] : (!fir.class>) -> !fir.ref> + +func.func @test_array(%arg0: !fir.ref>>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> + return +} +// CHECK-LABEL: func.func @test_array( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: hlfir.parent_comp %[[VAL_0]] shape %[[VAL_2]] : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> + +func.func @test_array_polymorphic(%arg0: !fir.class>>) { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.parent_comp %arg0 shape %1 : (!fir.class>>, !fir.shape<1>) -> !fir.box>> + return +} +// CHECK-LABEL: func.func @test_array_polymorphic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.class>>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: hlfir.parent_comp %[[VAL_0]] shape %[[VAL_2]] : (!fir.class>>, !fir.shape<1>) -> !fir.box>>