Index: flang/include/flang/Optimizer/Builder/HLFIRTools.h =================================================================== --- flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -374,14 +374,15 @@ mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>; /// Generate an hlfir.elementalOp given call back to generate the element /// value at for each iteration. -/// If exprType is specified, this will be the return type of the elemental op -hlfir::ElementalOp genElementalOp(mlir::Location loc, - fir::FirOpBuilder &builder, - mlir::Type elementType, mlir::Value shape, - mlir::ValueRange typeParams, - const ElementalKernelGenerator &genKernel, - bool isUnordered = false, - mlir::Type exprType = mlir::Type{}); +/// If exprType is specified, this will be the return type of the elemental op. +/// If exprType is not specified, the resulting expression type is computed +/// from the given \p elementType and \p shape, and the type is polymorphic +/// if \p polymorphicMold is present. +hlfir::ElementalOp genElementalOp( + mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType, + mlir::Value shape, mlir::ValueRange typeParams, + const ElementalKernelGenerator &genKernel, bool isUnordered = false, + mlir::Value polymorphicMold = {}, mlir::Type exprType = mlir::Type{}); /// Structure to describe a loop nest. struct LoopNest { Index: flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h =================================================================== --- flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h +++ flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h @@ -87,6 +87,7 @@ bool isI1Type(mlir::Type); // scalar i1 or logical, or sequence of logical (via (boxed?) array or expr) bool isMaskArgument(mlir::Type); +bool isPolymorphicObject(mlir::Type); /// If an expression's extents are known at compile time, generate a fir.shape /// for this expression. Otherwise return {} Index: flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td =================================================================== --- flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td +++ flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td @@ -149,6 +149,11 @@ def AnyFortranLogicalArrayObject : Type; +def IsPolymorphicObjectPred + : CPred<"::hlfir::isPolymorphicObject($_self)">; +def AnyPolymorphicObject : Type; + def hlfir_CharExtremumPredicateAttr : I32EnumAttr< "CharExtremumPredicate", "", [ Index: flang/include/flang/Optimizer/HLFIR/HLFIROps.td =================================================================== --- flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -740,7 +740,7 @@ let cppNamespace = "hlfir"; } -def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_ElementalOpInterface]> { +def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_ElementalOpInterface, AttrSizedOperandSegments]> { let summary = "elemental expression"; let description = [{ Represent an elemental expression as a function of the indices. @@ -753,6 +753,12 @@ The shape and typeparams operands represent the extents and type parameters of the resulting array value. + The optional mold is an entity carrying the information about + the dynamic type of the polymorphic result. Note that the shape + of the mold does not necessarily match the shape of the result, + for example, the result of `merge(poly_scalar1, poly_scalar2, mask_array)` + will have the shape of `mask_array` and the dynamic type of `poly_scalar*`. + The unordered attribute can be set to allow out of order processing of the indices. This is safe only if the operations in the body of the elemental do not have side effects. @@ -775,6 +781,7 @@ let arguments = (ins AnyShapeType:$shape, + Optional:$mold, Variadic:$typeparams, OptionalAttr:$unordered ); @@ -783,7 +790,8 @@ let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ - $shape (`typeparams` $typeparams^)? (`unordered` $unordered^)? + $shape (`mold` $mold^)? (`typeparams` $typeparams^)? + (`unordered` $unordered^)? attr-dict `:` functional-type(operands, results) $region }]; @@ -808,10 +816,12 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape, + CArg<"mlir::Value", "{}">:$mold, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"bool", "false">:$isUnordered)> ]; + let hasVerifier = 1; } def hlfir_YieldElementOp : hlfir_Op<"yield_element", [Terminator, HasParent<"ElementalOp">, Pure]> { Index: flang/lib/Lower/ConvertArrayConstructor.cpp =================================================================== --- flang/lib/Lower/ConvertArrayConstructor.cpp +++ flang/lib/Lower/ConvertArrayConstructor.cpp @@ -214,9 +214,9 @@ assert(!elementalOp && "expected only one implied-do"); mlir::Value one = builder.createIntegerConstant(loc, builder.getIndexType(), 1); - elementalOp = - builder.create(loc, exprType, shape, lengthParams, - /*isUnordered=*/true); + elementalOp = builder.create( + loc, exprType, shape, + /*mold=*/nullptr, lengthParams, /*isUnordered=*/true); builder.setInsertionPointToStart(elementalOp.getBody()); // implied-do-index = lower+((i-1)*stride) mlir::Value diff = builder.create( Index: flang/lib/Optimizer/Builder/HLFIRTools.cpp =================================================================== --- flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -737,16 +737,15 @@ isPolymorphic); } -hlfir::ElementalOp -hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::Type elementType, mlir::Value shape, - mlir::ValueRange typeParams, - const ElementalKernelGenerator &genKernel, - bool isUnordered, mlir::Type exprType) { +hlfir::ElementalOp hlfir::genElementalOp( + mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType, + mlir::Value shape, mlir::ValueRange typeParams, + const ElementalKernelGenerator &genKernel, bool isUnordered, + mlir::Value polymorphicMold, mlir::Type exprType) { if (!exprType) - exprType = getArrayExprType(elementType, shape, false); + exprType = getArrayExprType(elementType, shape, !!polymorphicMold); auto elementalOp = builder.create( - loc, exprType, shape, typeParams, isUnordered); + loc, exprType, shape, polymorphicMold, typeParams, isUnordered); auto insertPt = builder.saveInsertionPoint(); builder.setInsertionPointToStart(elementalOp.getBody()); mlir::Value elementResult = genKernel(loc, builder, elementalOp.getIndices()); Index: flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -181,6 +181,13 @@ return mlir::isa(elementType) || isI1Type(elementType); } +bool hlfir::isPolymorphicObject(mlir::Type type) { + if (auto exprType = mlir::dyn_cast(type)) + return exprType.isPolymorphic(); + + return fir::isPolymorphicType(type); +} + mlir::Value hlfir::genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc, const hlfir::ExprType &expr) { Index: flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -1036,10 +1036,17 @@ void hlfir::ElementalOp::build(mlir::OpBuilder &builder, mlir::OperationState &odsState, mlir::Type resultType, mlir::Value shape, - mlir::ValueRange typeparams, bool isUnordered) { + mlir::Value mold, mlir::ValueRange typeparams, + bool isUnordered) { odsState.addOperands(shape); + if (mold) + odsState.addOperands(mold); odsState.addOperands(typeparams); odsState.addTypes(resultType); + odsState.addAttribute( + getOperandSegmentSizesAttrName(odsState.name), + builder.getDenseI32ArrayAttr({/*shape=*/1, (mold ? 1 : 0), + static_cast(typeparams.size())})); if (isUnordered) odsState.addAttribute(getUnorderedAttrName(odsState.name), isUnordered ? builder.getUnitAttr() : nullptr); @@ -1057,6 +1064,16 @@ return mlir::cast(getBody()->back()).getElementValue(); } +mlir::LogicalResult hlfir::ElementalOp::verify() { + mlir::Value mold = getMold(); + hlfir::ExprType resultType = mlir::cast(getType()); + if (!!mold != resultType.isPolymorphic()) + return emitOpError("result must be polymorphic when mold is present " + "and vice versa"); + + return mlir::success(); +} + //===----------------------------------------------------------------------===// // ApplyOp //===----------------------------------------------------------------------===// Index: flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -58,7 +58,8 @@ }; hlfir::ElementalOp elementalOp = hlfir::genElementalOp( loc, builder, elementType, resultShape, typeParams, genKernel, - /*isUnordered=*/true, transpose.getResult().getType()); + /*isUnordered=*/true, /*polymorphicMold=*/nullptr, + transpose.getResult().getType()); // it wouldn't be safe to replace block arguments with a different // hlfir.expr type. Types can differ due to differing amounts of shape Index: flang/test/HLFIR/elemental.fir =================================================================== --- flang/test/HLFIR/elemental.fir +++ flang/test/HLFIR/elemental.fir @@ -99,3 +99,45 @@ // CHECK: } // CHECK: return // CHECK: } + +func.func @polymorphic_mold_var(%arg0: !fir.class>>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !fir.class>>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} +// CHECK-LABEL: func.func @polymorphic_mold_var( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.class>>, %[[VAL_1:.*]]: index) { +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_3:.*]] = hlfir.elemental %[[VAL_2]] mold %[[VAL_0]] unordered : (!fir.shape<1>, !fir.class>>) -> !hlfir.expr?> { +// CHECK: ^bb0(%[[VAL_4:.*]]: index): +// CHECK: %[[VAL_5:.*]] = fir.undefined !hlfir.expr?> +// CHECK: hlfir.yield_element %[[VAL_5]] : !hlfir.expr?> +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @polymorphic_mold_expr(%shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + %mold = fir.undefined !hlfir.expr?> + %4 = hlfir.elemental %3 mold %mold unordered : (!fir.shape<1>, !hlfir.expr?>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} +// CHECK-LABEL: func.func @polymorphic_mold_expr( +// CHECK-SAME: %[[VAL_0:.*]]: index) { +// CHECK: %[[VAL_1:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_2:.*]] = fir.undefined !hlfir.expr?> +// CHECK: %[[VAL_3:.*]] = hlfir.elemental %[[VAL_1]] mold %[[VAL_2]] unordered : (!fir.shape<1>, !hlfir.expr?>) -> !hlfir.expr?> { +// CHECK: ^bb0(%[[VAL_4:.*]]: index): +// CHECK: %[[VAL_5:.*]] = fir.undefined !hlfir.expr?> +// CHECK: hlfir.yield_element %[[VAL_5]] : !hlfir.expr?> +// CHECK: } +// CHECK: return +// CHECK: } Index: flang/test/HLFIR/invalid.fir =================================================================== --- flang/test/HLFIR/invalid.fir +++ flang/test/HLFIR/invalid.fir @@ -961,3 +961,51 @@ %1 = hlfir.get_length %arg0 : (!hlfir.expr>) -> index return } + +// ----- +func.func @elemental_poly_1(%arg0: !fir.box>>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.elemental' op operand #1 must be any polymorphic object, but got '!fir.box>>'}} + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !fir.box>>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} + +// ----- +func.func @elemental_poly_2(%arg0: !hlfir.expr>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.elemental' op operand #1 must be any polymorphic object, but got '!hlfir.expr>'}} + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !hlfir.expr>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} + +// ----- +func.func @elemental_poly_3(%arg0: !hlfir.expr?>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> +// expected-error@+1 {{'hlfir.elemental' op result must be polymorphic when mold is present and vice versa}} + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !hlfir.expr?>) -> !hlfir.expr> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr> + hlfir.yield_element %6 : !hlfir.expr> + } + return +} + +// ----- +func.func @elemental_poly_4(%shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> +// expected-error@+1 {{'hlfir.elemental' op result must be polymorphic when mold is present and vice versa}} + %4 = hlfir.elemental %3 unordered : (!fir.shape<1>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +}