diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1121,22 +1121,39 @@
 static Value getReductionInitValue(mlir::Location loc, mlir::Type type,
                                    llvm::StringRef reductionOpName,
                                    fir::FirOpBuilder &builder) {
-  if (type.isa())
+  assert((fir::isa_integer(type) || fir::isa_real(type) ||
+          type.isa()) &&
+         "only integer, logical and real types are currently supported");
+  if (reductionOpName.contains("max")) {
+    // A max reduction starts from the lowest finite value of the type.
+    if (auto ty = type.dyn_cast()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+    }
+    unsigned bits = type.getIntOrFloatBitWidth();
+    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+    return builder.createIntegerConstant(loc, type, minInt);
+  } else {
+    if (type.isa())
+      return builder.create(
+          loc, type,
+          builder.getFloatAttr(
+              type, (double)getOperationIdentity(reductionOpName, loc)));
+
+    if (type.isa()) {
+      Value intConst = builder.create(
+          loc, builder.getI1Type(),
+          builder.getIntegerAttr(builder.getI1Type(),
+                                 getOperationIdentity(reductionOpName, loc)));
+      return builder.createConvert(loc, type, intConst);
+    }
+
     return builder.create(
         loc, type,
-        builder.getFloatAttr(
-            type, (double)getOperationIdentity(reductionOpName, loc)));
-
-  if (type.isa()) {
-    Value intConst = builder.create(
-        loc, builder.getI1Type(),
-        builder.getIntegerAttr(builder.getI1Type(),
+        builder.getIntegerAttr(type,
                                getOperationIdentity(reductionOpName, loc)));
-    return builder.createConvert(loc, type, intConst);
   }
-  return builder.create(
-      loc, type,
-      builder.getIntegerAttr(type, getOperationIdentity(reductionOpName, loc)));
 }
 
 template
@@ -1150,6 +1167,68 @@
   return builder.create(loc, op1, op2);
 }
 
+/// Creates an `omp::ReductionDeclareOp` named `reductionOpName` at module scope
+/// and builds its initializer region. The reduction region only receives an
+/// empty block with two `type` arguments; the caller emits the combiner.
+static omp::ReductionDeclareOp
+createMinimalReductionDecl(fir::FirOpBuilder &builder,
+                           llvm::StringRef reductionOpName, mlir::Type type,
+                           mlir::Location loc) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+  mlir::omp::ReductionDeclareOp decl =
+      modBuilder.create(loc, reductionOpName, type);
+  builder.createBlock(&decl.getInitializerRegion(),
+                      decl.getInitializerRegion().end(), {type}, {loc});
+  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+  Value init = getReductionInitValue(loc, type, reductionOpName, builder);
+  builder.create(loc, init);
+
+  builder.createBlock(&decl.getReductionRegion(),
+                      decl.getReductionRegion().end(), {type, type},
+                      {loc, loc});
+
+  return decl;
+}
+
+/// Returns the reduction declaration named `reductionOpName`, creating it in
+/// the module when it does not exist yet. The combiner is chosen from the
+/// intrinsic procedure `procDesignator`; only `max` is currently handled.
+/// TODO: Generalize this for non-integer types, add atomic region.
+static omp::ReductionDeclareOp
+createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+                    const Fortran::parser::ProcedureDesignator &procDesignator,
+                    mlir::Type type, mlir::Location loc) {
+  OpBuilder::InsertionGuard guard(builder);
+  mlir::ModuleOp module = builder.getModule();
+
+  auto decl =
+      module.lookupSymbol(reductionOpName);
+  if (decl)
+    return decl;
+
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+  Value reductionOp;
+  if (const auto *name{
+          Fortran::parser::Unwrap(procDesignator)}) {
+    if (name->source == "max") {
+      reductionOp =
+          getReductionOperation(
+              builder, type, loc, op1, op2);
+    } else {
+      TODO(loc, "Reduction of some intrinsic operators is not supported");
+    }
+  }
+
+  builder.create(loc, reductionOp);
+  return decl;
+}
+
 /// Creates an OpenMP reduction declaration and inserts it into the provided
 /// symbol table. The declaration has a constant initializer with the neutral
 /// value `initValue`, and the reduction combiner carried over from `reduce`.
@@ -1160,23 +1239,13 @@
                     mlir::Type type, mlir::Location loc) {
   OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
-  mlir::OpBuilder modBuilder(module.getBodyRegion());
+
   auto decl =
       module.lookupSymbol(reductionOpName);
-  if (!decl)
-    decl =
-        modBuilder.create(loc, reductionOpName, type);
-  else
+  if (decl)
     return decl;
-  builder.createBlock(&decl.getInitializerRegion(),
-                      decl.getInitializerRegion().end(), {type}, {loc});
-  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-  Value init = getReductionInitValue(loc, type, reductionOpName, builder);
-  builder.create(loc, init);
 
-  builder.createBlock(&decl.getReductionRegion(),
-                      decl.getReductionRegion().end(), {type, type},
-                      {loc, loc});
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
@@ -1284,6 +1353,15 @@
   return mlir::omp::ScheduleModifier::none;
 }
 
+/// Builds the symbol name of a reduction declaration: `name` followed by
+/// `_i_` (integer/index types) or `_f_` (otherwise) and the type's bit width.
+static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+  return (llvm::Twine(name) +
+          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+          llvm::Twine(ty.getIntOrFloatBitWidth()))
+      .str();
+}
+
 static std::string getReductionName(
     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
     mlir::Type ty) {
@@ -1305,10 +1383,7 @@
     break;
   }
 
-  return (llvm::Twine(reductionName) +
-          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
-      .str();
+  return getReductionName(reductionName, ty);
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
@@ -1443,9 +1518,34 @@
               }
             }
           }
-        } else {
-          TODO(currentLocation,
-               "Reduction of intrinsic procedures is not supported");
+        } else if (auto reductionIntrinsic =
+                       std::get_if(
+                           &redOperator.u)) {
+          if (const auto *name{Fortran::parser::Unwrap(
+                  reductionIntrinsic)}) {
+            if (name->source != "max") {
+              TODO(currentLocation,
+                   "Reduction of intrinsic procedures is not supported");
+            }
+            for (const auto &ompObject : objectList.v) {
+              if (const auto *name{Fortran::parser::Unwrap(
+                      ompObject)}) {
+                if (const auto *symbol{name->symbol}) {
+                  mlir::Value symVal = converter.getSymbolAddress(*symbol);
+                  mlir::Type redType =
+                      symVal.getType().cast().getEleTy();
+                  reductionVars.push_back(symVal);
+                  assert(redType.isIntOrIndexOrFloat() &&
+                         "Unsupported reduction type");
+                  decl = createReductionDecl(
+                      firOpBuilder, getReductionName("max", redType),
+                      *reductionIntrinsic, redType, currentLocation);
+                  reductionDeclSymbols.push_back(SymbolRefAttr::get(
+                      firOpBuilder.getContext(), decl.getSymName()));
+                }
+              }
+            }
+          }
         }
       } else if (const auto &simdlenClause =
                      std::get_if(
@@ -2104,6 +2204,25 @@
       ompDeclConstruct.u);
 }
 
+/// Returns the operation defining one of `reductionOp`'s operands that itself
+/// takes `loadVal` as an operand, i.e. the compare feeding the select that
+/// `max` is lowered to. Returns nullptr when no such operation exists.
+static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp,
+                                                  mlir::Value loadVal) {
+  for (auto reductionOperand : reductionOp->getOperands()) {
+    if (auto compareOp = reductionOperand.getDefiningOp()) {
+      // Only accept a defining op that actually consumes the loaded value.
+      if (llvm::is_contained(compareOp->getOperands(), loadVal)) {
+        assert((mlir::isa(compareOp) ||
+                mlir::isa(compareOp)) &&
+               "Expected comparison not found in reduction intrinsic");
+        return compareOp;
+      }
+    }
+  }
+  return nullptr;
+}
+
 // Generate an OpenMP reduction operation.
 // TODO: Currently assumes it is either an integer addition/multiplication
 // reduction, or a logical and reduction. Generalize this for various reduction
@@ -2170,6 +2289,42 @@
             }
           }
         }
+      } else if (auto reductionIntrinsic =
+                     std::get_if(
+                         &redOperator.u)) {
+        if (const auto *name{Fortran::parser::Unwrap(
+                reductionIntrinsic)}) {
+          if (name->source != "max") {
+            continue;
+          }
+          for (const auto &ompObject : objectList.v) {
+            if (const auto *name{Fortran::parser::Unwrap(
+                    ompObject)}) {
+              if (const auto *symbol{name->symbol}) {
+                mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+                for (mlir::OpOperand &reductionValUse :
+                     reductionVal.getUses()) {
+                  if (auto loadOp = mlir::dyn_cast(
+                          reductionValUse.getOwner())) {
+                    mlir::Value loadVal = loadOp.getRes();
+                    // Max is lowered as a compare -> select.
+                    // Match the pattern here.
+                    mlir::Operation *reductionOp =
+                        findReductionChain(loadVal, &reductionVal);
+                    assert(mlir::isa(reductionOp) &&
+                           "Selection Op not found in reduction intrinsic");
+                    mlir::Operation *compareOp =
+                        getCompareFromReductionOp(reductionOp, loadVal);
+                    assert(compareOp &&
+                           "Compare Op not found in reduction intrinsic");
+                    updateReduction(compareOp, firOpBuilder, loadVal,
+                                    reductionVal);
+                  }
+                }
+              }
+            }
+          }
+        }
       }
     }
   }
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-max.f90 b/flang/test/Lower/OpenMP/Todo/reduction-max.f90
deleted file mode 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-max.f90
+++ /dev/null
@@ -1,16 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction of intrinsic procedures is not supported
-subroutine reduction_max(y)
-  integer :: x, y(:)
-  x = 0
-  !$omp parallel
-  !$omp do reduction(max:x)
-  do i=1, 100
-    x = max(x, y(i))
-  end do
-  !$omp end do
-  !$omp end parallel
-  print *, x
-end subroutine
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-max.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-max.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-max.f90
@@ -0,0 +1,66 @@
+! RUN: bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_F:.*]] : f32 init {
+!CHECK: %[[MINIMUM_VAL_F:.*]] = arith.constant -3.40282347E+38 : f32
+!CHECK: omp.yield(%[[MINIMUM_VAL_F]] : f32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_F:.*]]: f32, %[[ARG1_F:.*]]: f32):
+!CHECK: %[[COMB_VAL_F:.*]] = arith.maxf %[[ARG0_F]], %[[ARG1_F]] {{.*}}: f32
+!CHECK: omp.yield(%[[COMB_VAL_F]] : f32)
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_I:.*]] : i32 init {
+!CHECK: %[[MINIMUM_VAL_I:.*]] = arith.constant -2147483648 : i32
+!CHECK: omp.yield(%[[MINIMUM_VAL_I]] : i32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_I:.*]]: i32, %[[ARG1_I:.*]]: i32):
+!CHECK: %[[COMB_VAL_I:.*]] = arith.maxsi %[[ARG0_I]], %[[ARG1_I]] : i32
+!CHECK: omp.yield(%[[COMB_VAL_I]] : i32)
+
+!CHECK-LABEL: @_QPreduction_max_int
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box>
+!CHECK: %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFreduction_max_intEx"}
+!CHECK: omp.parallel
+!CHECK: omp.wsloop reduction(@[[MAX_DECLARE_I]] -> %[[X_REF]] : !fir.ref) for
+!CHECK: %[[Y_I_REF:.*]] = fir.coordinate_of %[[Y_BOX]]
+!CHECK: %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref
+!CHECK: omp.reduction %[[Y_I]], %[[X_REF]] : i32, !fir.ref
+!CHECK: omp.yield
+!CHECK: omp.terminator
+
+!CHECK-LABEL: @_QPreduction_max_real
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box>
+!CHECK: %[[X_REF:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFreduction_max_realEx"}
+!CHECK: omp.parallel
+!CHECK: omp.wsloop reduction(@[[MAX_DECLARE_F]] -> %[[X_REF]] : !fir.ref) for
+!CHECK: %[[Y_I_REF:.*]] = fir.coordinate_of %[[Y_BOX]]
+!CHECK: %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref
+!CHECK: omp.reduction %[[Y_I]], %[[X_REF]] : f32, !fir.ref
+!CHECK: omp.yield
+!CHECK: omp.terminator
+
+subroutine reduction_max_int(y)
+  integer :: x, y(:)
+  x = 0
+  !$omp parallel
+  !$omp do reduction(max:x)
+  do i=1, 100
+    x = max(x, y(i))
+  end do
+  !$omp end do
+  !$omp end parallel
+  print *, x
+end subroutine
+
+subroutine reduction_max_real(y)
+  real :: x, y(:)
+  x = 0.0
+  !$omp parallel
+  !$omp do reduction(max:x)
+  do i=1, 100
+    x = max(y(i), x)
+  end do
+  !$omp end do
+  !$omp end parallel
+  print *, x
+end subroutine