diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -227,6 +227,12 @@ mlir::Value genShape(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity); +/// Generate a vector of extents with index type from a fir.shape +/// or fir.shape_shift value. +llvm::SmallVector getIndexExtents(mlir::Location loc, + fir::FirOpBuilder &builder, + mlir::Value shape); + /// Read length parameters into result if this entity has any. void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, @@ -260,6 +266,10 @@ std::pair> genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents); +inline std::pair> +genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) { + return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape)); +} /// Inline the body of an hlfir.elemental at the current insertion point /// given a list of one based indices. 
This generates the computation diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -1110,7 +1110,14 @@ const { if (!characteristics) return true; - return characteristics->GetIntent() != Fortran::common::Intent::In; + if (characteristics->GetIntent() == Fortran::common::Intent::In) + return false; + const auto *dummy = + std::get_if( + &characteristics->u); + return !dummy || + !dummy->attrs.test( + Fortran::evaluate::characteristics::DummyDataObject::Attr::Value); } template bool Fortran::lower::CallInterface::PassedEntity::mayBeReadByCall() const { diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -429,6 +429,15 @@ namespace { class CallBuilder { +private: + struct PreparedActualArgument { + hlfir::Entity actual; + bool handleDynamicOptional; + }; + using PreparedActualArguments = + llvm::SmallVector>; + using PassBy = Fortran::lower::CallerInterface::PassEntityBy; + public: CallBuilder(mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symMap, @@ -439,20 +448,18 @@ gen(const Fortran::evaluate::ProcedureRef &procRef, llvm::Optional resultType) { mlir::Location loc = getLoc(); - fir::FirOpBuilder &builder = getBuilder(); - if (isElementalProcWithArrayArgs(procRef)) - TODO(loc, "lowering elemental call to HLFIR"); - if (auto *specific = procRef.proc().GetSpecificIntrinsic()) + if (auto *specific = procRef.proc().GetSpecificIntrinsic()) { + if (isElementalProcWithArrayArgs(procRef)) + TODO(loc, "lowering elemental intrinsic call to HLFIR"); return genIntrinsicRef(procRef, resultType, *specific); + } if (isStatementFunctionCall(procRef)) TODO(loc, "lowering Statement function call to HLFIR"); Fortran::lower::CallerInterface caller(procRef, converter); - using PassBy = 
Fortran::lower::CallerInterface::PassEntityBy; mlir::FunctionType callSiteType = caller.genFunctionType(); - llvm::SmallVector> - loweredActuals; + PreparedActualArguments loweredActuals; // Lower the actual arguments for (const Fortran::lower::CallInterface< Fortran::lower::CallerInterface>::PassedEntity &arg : @@ -461,41 +468,62 @@ const auto *expr = actual->UnwrapExpr(); if (!expr) TODO(loc, "assumed type actual argument"); - loweredActuals.emplace_back(Fortran::lower::convertExprToHLFIR( - loc, getConverter(), *expr, getSymMap(), getStmtCtx())); + + const bool handleDynamicOptional = + arg.isOptional() && Fortran::evaluate::MayBePassedAsAbsentOptional( + *expr, getConverter().getFoldingContext()); + auto loweredActual = Fortran::lower::convertExprToHLFIR( + loc, getConverter(), *expr, getSymMap(), getStmtCtx()); + loweredActuals.emplace_back( + PreparedActualArgument{loweredActual, handleDynamicOptional}); } else { // Optional dummy argument for which there is no actual argument. loweredActuals.emplace_back(std::nullopt); } + if (isElementalProcWithArrayArgs(procRef)) { + bool isImpure = false; + if (const Fortran::semantics::Symbol *procSym = + procRef.proc().GetSymbol()) + isImpure = !Fortran::semantics::IsPureProcedure(*procSym); + return genElementalUserCall(loweredActuals, caller, resultType, + callSiteType, isImpure); + } + return genUserCall(loweredActuals, caller, resultType, callSiteType); + } +private: + llvm::Optional + genUserCall(PreparedActualArguments &loweredActuals, + Fortran::lower::CallerInterface &caller, + llvm::Optional resultType, + mlir::FunctionType callSiteType) { + mlir::Location loc = getLoc(); + fir::FirOpBuilder &builder = getBuilder(); llvm::SmallVector exprAssociations; - for (auto [actual, arg] : + for (auto [preparedActual, arg] : llvm::zip(loweredActuals, caller.getPassedArguments())) { mlir::Type argTy = callSiteType.getInput(arg.firArgument); - if (!actual) { + if (!preparedActual) { // Optional dummy argument for which 
there is no actual argument. caller.placeInput(arg, builder.create(loc, argTy)); continue; } - + hlfir::Entity actual = preparedActual->actual; const auto *expr = arg.entity->UnwrapExpr(); if (!expr) TODO(loc, "assumed type actual argument"); - const bool actualMayBeDynamicallyAbsent = - arg.isOptional() && Fortran::evaluate::MayBePassedAsAbsentOptional( - *expr, getConverter().getFoldingContext()); - if (actualMayBeDynamicallyAbsent) + if (preparedActual->handleDynamicOptional) TODO(loc, "passing optional arguments in HLFIR"); const bool isSimplyContiguous = - actual->isScalar() || Fortran::evaluate::IsSimplyContiguous( - *expr, getConverter().getFoldingContext()); + actual.isScalar() || Fortran::evaluate::IsSimplyContiguous( + *expr, getConverter().getFoldingContext()); switch (arg.passBy) { case PassBy::Value: { // True pass-by-value semantics. - auto value = hlfir::loadTrivialScalar(loc, builder, *actual); + auto value = hlfir::loadTrivialScalar(loc, builder, actual); if (!value.isValue()) TODO(loc, "Passing CPTR an CFUNCTPTR VALUE in HLFIR"); caller.placeInput(arg, builder.createConvert(loc, argTy, value)); @@ -506,7 +534,7 @@ } break; case PassBy::BaseAddress: case PassBy::BoxChar: { - hlfir::Entity entity = *actual; + hlfir::Entity entity = actual; if (entity.isVariable()) { entity = hlfir::derefPointersAndAllocatables(loc, builder, entity); // Copy-in non contiguous variable @@ -556,11 +584,88 @@ builder.create(loc, associate); if (!fir::getBase(result)) return std::nullopt; // subroutine call. - return extendedValueToHlfirEntity(result, ".tmp.func_result"); // TODO: "move" non pointer results into hlfir.expr. 
+ return extendedValueToHlfirEntity(result, ".tmp.func_result"); + } + + llvm::Optional + genElementalUserCall(PreparedActualArguments &loweredActuals, + Fortran::lower::CallerInterface &caller, + llvm::Optional resultType, + mlir::FunctionType callSiteType, bool isImpure) { + mlir::Location loc = getLoc(); + fir::FirOpBuilder &builder = getBuilder(); + assert(loweredActuals.size() == caller.getPassedArguments().size()); + unsigned numArgs = loweredActuals.size(); + // Step 1: dereference pointers/allocatables and compute elemental shape. + mlir::Value shape; + // 10.1.4 p5. Impure elemental procedures must be called in element order. + bool mustBeOrdered = isImpure; + for (unsigned i = 0; i < numArgs; ++i) { + const auto &arg = caller.getPassedArguments()[i]; + auto &preparedActual = loweredActuals[i]; + if (preparedActual) { + hlfir::Entity &actual = preparedActual->actual; + // Elemental procedure dummy arguments cannot be pointer/allocatables + // (C15100), so it is safe to dereference any pointer or allocatable + // actual argument now instead of doing this inside the elemental + // region. + actual = hlfir::derefPointersAndAllocatables(loc, builder, actual); + // Better to load scalars outside of the loop when possible. + if (!preparedActual->handleDynamicOptional && + (arg.passBy == PassBy::Value || + arg.passBy == PassBy::BaseAddressValueAttribute)) + actual = hlfir::loadTrivialScalar(loc, builder, actual); + // TODO: merge shape instead of using the first one. + if (!shape && actual.isArray()) { + if (preparedActual->handleDynamicOptional) + TODO(loc, "deal with optional with shapes in HLFIR elemental call"); + shape = hlfir::genShape(loc, builder, actual); + } + // 15.8.3 p1. Elemental procedure with intent(out)/intent(inout) + // arguments must be called in element order. 
+ if (arg.mayBeModifiedByCall()) + mustBeOrdered = true; + } + } + assert(shape && + "elemental array calls must have at least one array arguments"); + if (mustBeOrdered) + TODO(loc, "ordered elemental calls in HLFIR"); + if (!resultType) { + // Subroutine case. Generate call inside loop nest. + auto [innerLoop, oneBasedIndices] = + hlfir::genLoopNest(loc, builder, shape); + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(innerLoop.getBody()); + for (auto &preparedActual : loweredActuals) + if (preparedActual) + preparedActual->actual = hlfir::getElementAt( + loc, builder, preparedActual->actual, oneBasedIndices); + genUserCall(loweredActuals, caller, resultType, callSiteType); + builder.restoreInsertionPoint(insPt); + return std::nullopt; + } + // Function case: generate call inside hlfir.elemental + mlir::Type elementType = hlfir::getFortranElementType(*resultType); + // Get result length parameters. + llvm::SmallVector typeParams; + if (elementType.isa() || + fir::isRecordWithTypeParameters(elementType)) + TODO(loc, "compute elemental function result length parameters in HLFIR"); + auto genKernel = [&](mlir::Location l, fir::FirOpBuilder &b, + mlir::ValueRange oneBasedIndices) -> hlfir::Entity { + for (auto &preparedActual : loweredActuals) + if (preparedActual) + preparedActual->actual = hlfir::getElementAt( + l, b, preparedActual->actual, oneBasedIndices); + return *genUserCall(loweredActuals, caller, resultType, callSiteType); + }; + // TODO: deal with hlfir.elemental result destruction. 
+ return hlfir::EntityWithAttributes{hlfir::genElementalOp( + loc, builder, elementType, shape, typeParams, genKernel)}; } -private: hlfir::EntityWithAttributes genIntrinsicRef(const Fortran::evaluate::ProcedureRef &procRef, llvm::Optional resultType, diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -396,6 +396,26 @@ return builder.create(loc, extents); } +llvm::SmallVector +hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder, + mlir::Value shape) { + llvm::SmallVector extents; + if (auto s = shape.getDefiningOp()) { + auto e = s.getExtents(); + extents.insert(extents.end(), e.begin(), e.end()); + } else if (auto s = shape.getDefiningOp()) { + auto e = s.getExtents(); + extents.insert(extents.end(), e.begin(), e.end()); + } else { + // TODO: add fir.get_extent ops on fir.shape<> ops. + TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op"); + } + mlir::Type indexType = builder.getIndexType(); + for (auto &extent : extents) + extent = builder.createConvert(loc, indexType, extent); + return extents; +} + void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, llvm::SmallVectorImpl &result) { diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -95,26 +95,6 @@ TODO(bufferizedExpr.getLoc(), "general extract storage case"); } -static llvm::SmallVector -getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::Value shape) { - llvm::SmallVector extents; - if (auto s = shape.getDefiningOp()) { - auto e = s.getExtents(); - extents.insert(extents.end(), e.begin(), e.end()); - } else if (auto s = shape.getDefiningOp()) { - auto e = 
s.getExtents(); - extents.insert(extents.end(), e.begin(), e.end()); - } else { - // TODO: add fir.get_extent ops on fir.shape<> ops. - TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op"); - } - mlir::Type indexType = builder.getIndexType(); - for (auto &extent : extents) - extent = builder.createConvert(loc, indexType, extent); - return extents; -} - static std::pair createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity mold) { @@ -128,7 +108,7 @@ mlir::Type sequenceType = hlfir::getFortranElementOrSequenceType(mold.getType()); shape = hlfir::genShape(loc, builder, mold); - auto extents = getIndexExtents(loc, builder, shape); + auto extents = hlfir::getIndexExtents(loc, builder, shape); alloc = builder.createHeapTemporary(loc, sequenceType, tmpName, extents, lenParams); isHeapAlloc = builder.createBool(loc, true); @@ -369,7 +349,7 @@ builder.setListener(&listener); mlir::Value shape = adaptor.getShape(); - auto extents = getIndexExtents(loc, builder, shape); + auto extents = hlfir::getIndexExtents(loc, builder, shape); auto [temp, cleanup] = createArrayTemp(loc, builder, elemental.getType(), shape, extents, adaptor.getTypeparams()); diff --git a/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 b/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 @@ -0,0 +1,92 @@ +! Test lowering of user defined elemental procedure reference to HLFIR +! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s + +subroutine by_addr(x, y) + integer :: x + real :: y(100) + interface + real elemental function elem(a, b) + integer, intent(in) :: a + real, intent(in) :: b + end function + end interface + call baz(elem(x, y)) +end subroutine +! CHECK-LABEL: func.func @_QPby_addr( +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_4:[^)]*]]) {{.*}}y +! 
CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_4]] : (!fir.shape<1>) -> !hlfir.expr<100xf32> { +! CHECK: ^bb0(%[[VAL_7:.*]]: index): +! CHECK: %[[VAL_8:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_7]]) : (!fir.ref>, index) -> !fir.ref +! CHECK: %[[VAL_9:.*]] = fir.call @_QPelem(%[[VAL_2]]#1, %[[VAL_8]]) fastmath : (!fir.ref, !fir.ref) -> f32 +! CHECK: hlfir.yield_element %[[VAL_9]] : f32 +! CHECK: } + +subroutine by_value(x, y) + integer :: x + real :: y(10, 20) + interface + real elemental function elem_val(a, b) + integer, value :: a + real, value :: b + end function + end interface + call baz(elem_val(x, y)) +end subroutine +! CHECK-LABEL: func.func @_QPby_value( +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_5:[^)]*]]) {{.*}}y +! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref +! CHECK: %[[VAL_8:.*]] = hlfir.elemental %[[VAL_5]] : (!fir.shape<2>) -> !hlfir.expr<10x20xf32> { +! CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: index): +! CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]], %[[VAL_10]]) : (!fir.ref>, index, index) -> !fir.ref +! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref +! CHECK: %[[VAL_13:.*]] = fir.call @_QPelem_val(%[[VAL_7]], %[[VAL_12]]) fastmath : (i32, f32) -> f32 +! CHECK: hlfir.yield_element %[[VAL_13]] : f32 +! CHECK: } + +subroutine by_boxaddr(x, y) + character(*) :: x + character(*) :: y(100) + interface + real elemental function char_elem(a, b) + character(*), intent(in) :: a + character(*), intent(in) :: b + end function + end interface + call baz2(char_elem(x, y)) +end subroutine +! CHECK-LABEL: func.func @_QPby_boxaddr( +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2:.*]]#0 typeparams %[[VAL_2]]#1 {{.*}}x +! CHECK: %[[VAL_6:.*]] = arith.constant 100 : index +! CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_5:.*]](%[[VAL_7:.*]]) typeparams %[[VAL_4:.*]]#1 {{.*}}y +! 
CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_7]] : (!fir.shape<1>) -> !hlfir.expr<100xf32> { +! CHECK: ^bb0(%[[VAL_10:.*]]: index): +! CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_10]]) typeparams %[[VAL_4]]#1 : (!fir.box>>, index, index) -> !fir.boxchar<1> +! CHECK: %[[VAL_12:.*]] = fir.call @_QPchar_elem(%[[VAL_3]]#0, %[[VAL_11]]) fastmath : (!fir.boxchar<1>, !fir.boxchar<1>) -> f32 +! CHECK: hlfir.yield_element %[[VAL_12]] : f32 +! CHECK: } + +subroutine sub(x, y) + integer :: x + real :: y(10, 20) + interface + elemental subroutine elem_sub(a, b) + integer, intent(in) :: a + real, intent(in) :: b + end subroutine + end interface + call elem_sub(x, y) +end subroutine +! CHECK-LABEL: func.func @_QPsub( +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x +! CHECK: %[[VAL_3:.*]] = arith.constant 10 : index +! CHECK: %[[VAL_4:.*]] = arith.constant 20 : index +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_5:[^)]*]]) {{.*}}y +! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +! CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_7]] { +! CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] { +! CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]], %[[VAL_8]]) : (!fir.ref>, index, index) -> !fir.ref +! CHECK: fir.call @_QPelem_sub(%[[VAL_2]]#1, %[[VAL_10]]) fastmath : (!fir.ref, !fir.ref) -> () +! CHECK: } +! CHECK: }