diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -15,6 +15,14 @@
 
 #include <cinttypes>
 
+namespace mlir {
+class Value;
+} // namespace mlir
+
+namespace fir {
+class FirOpBuilder;
+} // namespace fir
+
 namespace Fortran {
 namespace parser {
 struct OpenMPConstruct;
@@ -41,6 +49,14 @@
 void genOpenMPReduction(AbstractConverter &,
                         const Fortran::parser::OmpClauseList &clauseList);
 
+/// Rewrites a `load symVal -> op -> store to symVal` chain, where `op` is the
+/// integer reduction operation, into an `omp.reduction` operation. The
+/// definition is only visible in flang/lib/Lower/OpenMP.cpp, so this must not
+/// be instantiated from other translation units.
+template <typename OpType>
+void updateReduction(OpType &op, fir::FirOpBuilder &firOpBuilder,
+                     mlir::Value loadVal, mlir::Value symVal);
+
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -783,14 +783,29 @@
   }
 }
 
+/// Returns the neutral (identity) value of the integer reduction
+/// `intrinsicOp`: 0 for `+` and 1 for `*`.
+static int getOperationIdentity(
+    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+    mlir::Location loc) {
+  switch (intrinsicOp) {
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    return 0;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    return 1;
+  default:
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
+  }
+}
+
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
+/// Creates an OpenMP reduction declaration named `name` in the module. Its
+/// initializer yields the neutral value of `intrinsicOp` and its combiner
+/// applies `intrinsicOp` to the two incoming values.
 /// TODO: Generalize this for non-integer types, add atomic region.
-static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder,
-                                                   llvm::StringRef name,
-                                                   mlir::Type type,
-                                                   mlir::Location loc) {
+static omp::ReductionDeclareOp createReductionDecl(
+    fir::FirOpBuilder &builder, llvm::StringRef name,
+    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+    mlir::Type type, mlir::Location loc) {
   OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
   mlir::OpBuilder modBuilder(module.getBodyRegion());
@@ -804,7 +819,8 @@
                       {type}, {loc});
   builder.setInsertionPointToEnd(&decl.initializerRegion().back());
   Value init = builder.create<mlir::arith::ConstantOp>(
-      loc, type, builder.getIntegerAttr(type, 0));
+      loc, type,
+      builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
   builder.create<omp::YieldOp>(loc, init);
 
   builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(),
@@ -812,8 +828,20 @@
   builder.setInsertionPointToEnd(&decl.reductionRegion().back());
   mlir::Value op1 = decl.reductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.reductionRegion().front().getArgument(1);
-  Value addRes = builder.create<mlir::arith::AddIOp>(loc, op1, op2);
-  builder.create<omp::YieldOp>(loc, addRes);
+
+  Value res;
+  switch (intrinsicOp) {
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    res = builder.create<mlir::arith::AddIOp>(loc, op1, op2);
+    break;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    res = builder.create<mlir::arith::MulIOp>(loc, op1, op2);
+    break;
+  default:
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
+  }
+
+  builder.create<omp::YieldOp>(loc, res);
   return decl;
 }
 
@@ -885,10 +913,18 @@
     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
     mlir::Type ty) {
   std::string reductionName;
-  if (intrinsicOp == Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
+
+  switch (intrinsicOp) {
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
     reductionName = "add_reduction";
-  else
+    break;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    reductionName = "multiply_reduction";
+    break;
+  default:
     reductionName = "other_reduction";
+    break;
+  }
 
   return (llvm::Twine(reductionName) +
           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
@@ -990,10 +1026,15 @@
         const auto &intrinsicOp{
             std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
                 redDefinedOp->u)};
-        if (intrinsicOp !=
-            Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
+        switch (intrinsicOp) {
+        case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+        case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+          break;
+        default:
           TODO(currentLocation,
                "Reduction of some intrinsic operators is not supported");
+          break;
+        }
         for (const auto &ompObject : objectList.v) {
           if (const auto *name{
                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
@@ -1005,7 +1046,7 @@
               if (redType.isIntOrIndex()) {
                 decl = createReductionDecl(
                     firOpBuilder, getReductionName(intrinsicOp, redType),
-                    redType, currentLocation);
+                    intrinsicOp, redType, currentLocation);
               } else {
                 TODO(currentLocation,
                      "Reduction of some types is not supported");
@@ -1604,8 +1645,8 @@
 // Generate an OpenMP reduction operation. This implementation finds the chain :
 // load reduction var -> reduction_operation -> store reduction var and replaces
 // it with the reduction operation.
-// TODO: Currently assumes it is an integer addition reduction. Generalize this
-// for various reduction operation types.
+// TODO: Currently assumes it is an integer addition/multiplication reduction.
+// Generalize this for various reduction operation types.
 // TODO: Generate the reduction operation during lowering instead of creating
 // and removing operations since this is not a robust approach. Also, removing
 // ops in the builder (instead of a rewriter) is probably not the best approach.
@@ -1626,9 +1660,14 @@
         const auto &intrinsicOp{
             std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
                 reductionOp->u)};
-        if (intrinsicOp !=
-            Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
+
+        switch (intrinsicOp) {
+        case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+        case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+          break;
+        default:
           continue;
+        }
         for (const auto &ompObject : objectList.v) {
           if (const auto *name{
                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
@@ -1642,33 +1681,24 @@
                 if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
                   mlir::Value loadVal = load.getRes();
-                  for (mlir::OpOperand &use2 : loadVal.getUses()) {
-                    if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
-                            use2.getOwner())) {
-                      mlir::Value addRes = add.getResult();
-                      for (mlir::OpOperand &use3 : addRes.getUses()) {
-                        if (auto store =
-                                mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
-                          if (store.getMemref() == symVal) {
-                            // Chain found! Now replace load->reduction->store
-                            // with the OpenMP reduction operation.
-                            mlir::OpBuilder::InsertPoint insertPtDel =
-                                firOpBuilder.saveInsertionPoint();
-                            firOpBuilder.setInsertionPoint(add);
-                            if (add.getLhs() == loadVal) {
-                              firOpBuilder.create<mlir::omp::ReductionOp>(
-                                  add.getLoc(), add.getRhs(), symVal);
-                            } else {
-                              firOpBuilder.create<mlir::omp::ReductionOp>(
-                                  add.getLoc(), add.getLhs(), symVal);
-                            }
-                            store.erase();
-                            add.erase();
-                            load.erase();
-                            firOpBuilder.restoreInsertionPoint(insertPtDel);
-                          }
-                        }
-                      }
-                    }
+                  // A load without users cannot start a reduction chain.
+                  if (loadVal.use_empty())
+                    continue;
+                  // updateReduction may erase the current user (and its use of
+                  // `loadVal`), so advance the iterator before visiting it.
+                  for (mlir::OpOperand &use2 :
+                       llvm::make_early_inc_range(loadVal.getUses())) {
+                    mlir::Operation *owner = use2.getOwner();
+                    if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(owner))
+                      updateReduction(add, firOpBuilder, loadVal, symVal);
+                    else if (auto mul =
+                                 mlir::dyn_cast<mlir::arith::MulIOp>(owner))
+                      updateReduction(mul, firOpBuilder, loadVal, symVal);
                   }
+                  // Erase the load only if every user was rewritten; otherwise
+                  // it still feeds operations that are not reduction chains.
+                  // NOTE(review): this frees `use1` while the enclosing loop
+                  // over symVal's uses is still iterating; pre-existing hazard.
+                  if (loadVal.use_empty())
+                    load.erase();
                 }
               }
@@ -1679,3 +1709,28 @@
     }
   }
 }
+
+/// Rewrites the chain `load symVal -> op -> store to symVal` into a single
+/// `omp.reduction` of op's operand other than `loadVal` into `symVal`, then
+/// erases `op` and the store. Does nothing if no user of `op` stores to
+/// `symVal`. The load itself is not touched; erasing it is left to the caller.
+template <typename OpType>
+void Fortran::lower::updateReduction(OpType &op,
+                                     fir::FirOpBuilder &firOpBuilder,
+                                     mlir::Value loadVal, mlir::Value symVal) {
+  for (mlir::OpOperand &use3 : op.getResult().getUses()) {
+    auto store = mlir::dyn_cast<fir::StoreOp>(use3.getOwner());
+    if (!store || store.getMemref() != symVal)
+      continue;
+    // Chain found! Replace load->reduction->store with the OpenMP reduction
+    // operation, emitted where `op` was.
+    mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+    firOpBuilder.setInsertionPoint(op);
+    mlir::Value reduced = op.getLhs() == loadVal ? op.getRhs() : op.getLhs();
+    firOpBuilder.create<mlir::omp::ReductionOp>(op.getLoc(), reduced, symVal);
+    store.erase();
+    op.erase();
+    // `op`, and with it the use list being walked, is gone: stop here.
+    return;
+  }
+}
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-multiply.f90 b/flang/test/Lower/OpenMP/Todo/reduction-multiply.f90
deleted file mode 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-multiply.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction of some intrinsic operators is not supported
-subroutine reduction_multiply
-  integer :: x
-  !$omp parallel
-  !$omp do reduction(*:x)
-  do i=1, 100
-    x = x * i
-  end do
-  !$omp end do
-  !$omp end parallel
-  print *, x
-end subroutine
diff --git a/flang/test/Lower/OpenMP/reduction-multiply.f90 b/flang/test/Lower/OpenMP/reduction-multiply.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/OpenMP/reduction-multiply.f90
@@ -0,0 +1,144 @@
+! RUN: bbc -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I64_NAME:.*]] : i64 init {
+!CHECK: ^bb0(%{{.*}}: i64):
+!CHECK: %[[C1_1:.*]] = arith.constant 1 : i64
+!CHECK: omp.yield(%[[C1_1]] : i64)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64):
+!CHECK: %[[RES:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : i64
+!CHECK: omp.yield(%[[RES]] : i64)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32
+!CHECK: omp.yield(%[[C1_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK: %[[RES:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : i32
+!CHECK: omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_reduction
+!CHECK: %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_reductionEx"}
+!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32
+!CHECK: fir.store %[[C1_2]] to %[[XREF]] : !fir.ref<i32>
+!CHECK: omp.parallel
+!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
+!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32
+!CHECK: %[[C10:.*]] = arith.constant 10 : i32
+!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32
+!CHECK: omp.wsloop reduction(@[[RED_I32_NAME]] -> %[[XREF]] : !fir.ref<i32>) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C10]]) inclusive step (%[[C1_2]])
+!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: omp.reduction %[[I_PVT_VAL]], %[[XREF]] : !fir.ref<i32>
+!CHECK: omp.yield
+!CHECK: omp.terminator
+!CHECK: return
+
+subroutine simple_reduction
+  integer :: x
+  x = 1
+  !$omp parallel
+  !$omp do reduction(*:x)
+  do i=1, 10
+    x = x * i
+  end do
+  !$omp end do
+  !$omp end parallel
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_reduction_switch_order
+!CHECK: %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_reduction_switch_orderEx"}
+!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32
+!CHECK: fir.store %[[C1_2]] to %[[XREF]] : !fir.ref<i32>
+!CHECK: omp.parallel
+!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
+!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32
+!CHECK: %[[C10:.*]] = arith.constant 10 : i32
+!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32
+!CHECK: omp.wsloop reduction(@[[RED_I32_NAME]] -> %[[XREF]] : !fir.ref<i32>) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C10]]) inclusive step (%[[C1_2]])
+!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: omp.reduction %[[I_PVT_VAL]], %[[XREF]] : !fir.ref<i32>
+!CHECK: omp.yield
+!CHECK: omp.terminator
+!CHECK: return
+
+subroutine simple_reduction_switch_order
+  integer :: x
+  x = 1
+  !$omp parallel
+  !$omp do reduction(*:x)
+  do i=1, 10
+    x = i * x
+  end do
+  !$omp end do
+  !$omp end parallel
+end subroutine
+
+!CHECK-LABEL: func.func @_QPmultiple_reductions_same_type
+!CHECK: %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFmultiple_reductions_same_typeEx"}
+!CHECK: %[[YREF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFmultiple_reductions_same_typeEy"}
+!CHECK: %[[ZREF:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFmultiple_reductions_same_typeEz"}
+!CHECK: omp.parallel
+!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
+!CHECK: omp.wsloop reduction(@[[RED_I32_NAME]] -> %[[XREF]] : !fir.ref<i32>, @[[RED_I32_NAME]] -> %[[YREF]] : !fir.ref<i32>, @[[RED_I32_NAME]] -> %[[ZREF]] : !fir.ref<i32>) for (%[[IVAL:.*]]) : i32
+!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: %[[I_PVT_VAL1:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: omp.reduction %[[I_PVT_VAL1]], %[[XREF]] : !fir.ref<i32>
+!CHECK: %[[I_PVT_VAL2:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: omp.reduction %[[I_PVT_VAL2]], %[[YREF]] : !fir.ref<i32>
+!CHECK: %[[I_PVT_VAL3:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: omp.reduction %[[I_PVT_VAL3]], %[[ZREF]] : !fir.ref<i32>
+!CHECK: omp.yield
+!CHECK: omp.terminator
+!CHECK: return
+
+subroutine multiple_reductions_same_type
+  integer :: x,y,z
+  x = 1
+  y = 1
+  z = 1
+  !$omp parallel
+  !$omp do reduction(*:x,y,z)
+  do i=1, 10
+    x = x * i
+    y = y * i
+    z = z * i
+  end do
+  !$omp end do
+  !$omp end parallel
+end subroutine
+
+!CHECK-LABEL: func.func @_QPmultiple_reductions_different_type
+!CHECK: %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFmultiple_reductions_different_typeEx"}
+!CHECK: %[[YREF:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFmultiple_reductions_different_typeEy"}
+!CHECK: omp.parallel
+!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
+!CHECK: omp.wsloop reduction(@[[RED_I32_NAME]] -> %[[XREF]] : !fir.ref<i32>, @[[RED_I64_NAME]] -> %[[YREF]] : !fir.ref<i64>) for (%[[IVAL:.*]]) : i32
+!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref<i32>
+!CHECK: %[[C2_32:.*]] = arith.constant 2 : i32
+!CHECK: omp.reduction %[[C2_32]], %[[XREF]] : !fir.ref<i32>
+!CHECK: %[[C2_64:.*]] = arith.constant 2 : i64
+!CHECK: omp.reduction %[[C2_64]], %[[YREF]] : !fir.ref<i64>
+!CHECK: omp.yield
+!CHECK: omp.terminator
+!CHECK: return
+
+subroutine multiple_reductions_different_type
+  integer :: x
+  integer(kind=8) :: y
+  !$omp parallel
+  !$omp do reduction(*:x,y)
+  do i=1, 10
+    x = x * 2_4
+    y = y * 2_8
+  end do
+  !$omp end do
+  !$omp end parallel
+end subroutine