diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -22,6 +22,7 @@ namespace fir { class FirOpBuilder; +class ConvertOp; } // namespace fir namespace Fortran { @@ -50,11 +51,11 @@ void genOpenMPReduction(AbstractConverter &, const Fortran::parser::OmpClauseList &clauseList); +mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr); +fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value); void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, - mlir::Value); - -mlir::Operation *getReductionInChain(mlir::Value, mlir::Value); - + mlir::Value, fir::ConvertOp * = nullptr); +void removeStoreOp(mlir::Operation *, mlir::Value); } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -850,7 +850,8 @@ mlir::Location loc) { if (reductionOpName.contains("add")) return 0; - else if (reductionOpName.contains("multiply")) + else if (reductionOpName.contains("multiply") || + reductionOpName.contains("and")) return 1; TODO(loc, "Reduction of some intrinsic operators is not supported"); } @@ -858,7 +859,8 @@ static Value getReductionInitValue(mlir::Location loc, mlir::Type type, llvm::StringRef reductionOpName, fir::FirOpBuilder &builder) { - assert(type.isIntOrIndexOrFloat() && "only integer and float types are currently supported"); + assert(type.isIntOrIndexOrFloat() && + "only integer and float types are currently supported"); if (type.isa()) return builder.create( loc, type, @@ -874,7 +876,8 @@ static Value getReductionOperation(fir::FirOpBuilder &builder, mlir::Type type, mlir::Location loc, mlir::Value op1, mlir::Value op2) { - assert(type.isIntOrIndexOrFloat() && "only integer and float types are currently supported"); + assert(type.isIntOrIndexOrFloat() && + "only integer and float types are currently supported"); if (type.isIntOrIndex()) return builder.create(loc, op1, op2); return builder.create(loc, op1, op2); @@ -898,7 +901,6 @@ modBuilder.create(loc, reductionOpName, type); else return decl; - builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), {type}, {loc}); builder.setInsertionPointToEnd(&decl.initializerRegion().back()); @@ -923,6 +925,9 @@ getReductionOperation( builder, type, loc, op1, op2); break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + reductionOp = builder.create(loc, op1, op2); + break; default: TODO(loc, "Reduction of some intrinsic operators is not supported"); } @@ -1007,6 +1012,8 @@ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return "and_reduction"; default: reductionName = "other_reduction"; break; @@ -1115,6 +1122,7 @@ switch (intrinsicOp) { case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: break; default: @@ -1130,6 +1138,8 @@ mlir::Type redType = symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); + if (redType.isa()) + redType = firOpBuilder.getI1Type(); if (redType.isIntOrIndexOrFloat()) { decl = createReductionDecl( firOpBuilder, getReductionName(intrinsicOp, redType), @@ -1785,11 +1795,10 @@ ompDeclConstruct.u); } -// Generate an OpenMP reduction operation. This implementation finds the chain : -// load reduction var -> reduction_operation -> store reduction var and replaces -// it with the reduction operation. -// TODO: Currently assumes it is an integer addition/multiplication reduction. -// Generalize this for various reduction operation types. +// Generate an OpenMP reduction operation. +// TODO: Currently assumes it is either an integer addition/multiplication +// reduction, or a logical and reduction. Generalize this for various reduction +// operation types. // TODO: Generate the reduction operation during lowering instead of creating // and removing operations since this is not a robust approach. Also, removing // ops in the builder (instead of a rewriter) is probably not the best approach. @@ -1814,6 +1823,7 @@ switch (intrinsicOp) { case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: break; default: continue; @@ -1825,16 +1835,28 @@ mlir::Value reductionVal = converter.getSymbolAddress(*symbol); mlir::Type reductionType = reductionVal.getType().cast().getEleTy(); - if (!reductionType.isIntOrIndexOrFloat()) - continue; + if (intrinsicOp != + Fortran::parser::DefinedOperator::IntrinsicOperator::AND) { + if (!reductionType.isIntOrIndexOrFloat()) + continue; + } for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - - if (auto loadOp = - mlir::dyn_cast(reductionValUse.getOwner())) { + if (auto loadOp = mlir::dyn_cast( + reductionValUse.getOwner())) { mlir::Value loadVal = loadOp.getRes(); - if (auto reductionOp = getReductionInChain(reductionVal, loadVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal); + if (intrinsicOp == Fortran::parser::DefinedOperator:: + IntrinsicOperator::AND) { + mlir::Operation *reductionOp = findReductionChain(loadVal); + fir::ConvertOp convertOp = + getConvertFromReductionOp(reductionOp, loadVal); + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal, &convertOp); + removeStoreOp(reductionOp, reductionVal); + } else if (auto reductionOp = + findReductionChain(loadVal, &reductionVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } @@ -1846,19 +1868,20 @@ } } -// Checks whether loadVal is used in an operation, -// the result of which is then stored into reductionVal. -// If yes, then the operation corresponding to the reduction is returned. -// loadVal is assumed to be the value of a load operation -// reductionVal is the results of an OpenMP reduction operation. -mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value reductionVal, - mlir::Value loadVal) { - for (mlir::OpOperand &loadUse : loadVal.getUses()) { - if (auto reductionOp = loadUse.getOwner()) { +mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal, + mlir::Value *reductionVal) { + for (mlir::OpOperand &loadOperand : loadVal.getUses()) { + if (auto reductionOp = loadOperand.getOwner()) { + if (auto convertOp = mlir::dyn_cast(reductionOp)) { + for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) { + if (auto reductionOp = convertOperand.getOwner()) + return reductionOp; + } + } for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) { if (auto store = mlir::dyn_cast(reductionOperand.getOwner())) { - if (store.getMemref() == reductionVal) { + if (store.getMemref() == *reductionVal) { store.erase(); return reductionOp; } @@ -1871,16 +1894,53 @@ void Fortran::lower::updateReduction(mlir::Operation *op, fir::FirOpBuilder &firOpBuilder, - mlir::Value loadVal, mlir::Value reductionVal) { + mlir::Value loadVal, + mlir::Value reductionVal, + fir::ConvertOp *convertOp) { mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint(); firOpBuilder.setInsertionPoint(op); - if (op->getOperand(0) == loadVal) - firOpBuilder.create(op->getLoc(), op->getOperand(1), - reductionVal); + mlir::Value reductionOp; + if (convertOp) + reductionOp = convertOp->getOperand(); + else if (op->getOperand(0) == loadVal) + reductionOp = op->getOperand(1); else - firOpBuilder.create(op->getLoc(), op->getOperand(0), - reductionVal); + reductionOp = op->getOperand(0); + firOpBuilder.create(op->getLoc(), reductionOp, + reductionVal); firOpBuilder.restoreInsertionPoint(insertPtDel); } + +// for a logical operator 'op' reduction X = X op Y +// This function returns the operation responsible for converting Y from +// fir.logical<4> to i1 +fir::ConvertOp +Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp, + mlir::Value loadVal) { + for (auto reductionOperand : reductionOp->getOperands()) { + if (auto convertOp = + mlir::dyn_cast(reductionOperand.getDefiningOp())) { + if (convertOp.getOperand() == loadVal) + continue; + return convertOp; + } + } + return nullptr; +} + +void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp, + mlir::Value symVal) { + for (auto reductionOpUse : reductionOp->getUsers()) { + if (auto convertReduction = + mlir::dyn_cast(reductionOpUse)) { + for (auto convertReductionUse : convertReduction.getRes().getUsers()) { + if (auto storeOp = mlir::dyn_cast(convertReductionUse)) { + if (storeOp.getMemref() == symVal) + storeOp.erase(); + } + } + } + } +} diff --git a/flang/test/Lower/OpenMP/Todo/reduction-and.f90 b/flang/test/Lower/OpenMP/Todo/reduction-and.f90 deleted file mode 100644 --- a/flang/test/Lower/OpenMP/Todo/reduction-and.f90 +++ /dev/null @@ -1,15 +0,0 @@ -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s - -! CHECK: not yet implemented: Reduction of some intrinsic operators is not supported -subroutine reduction_and(y) - logical :: x, y(100) - !$omp parallel - !$omp do reduction(.and.:x) - do i=1, 100 - x = x .and. y(i) - end do - !$omp end do - !$omp end parallel - print *, x -end subroutine diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 @@ -0,0 +1,133 @@ +! RUN: bbc -emit-fir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +!CHECK-LABEL: omp.reduction.declare +!CHECK-SAME: @[[RED_NAME:.*]] : i1 init { +!CHECK: ^bb0(%{{.*}}: i1): +!CHECK: %true = arith.constant true +!CHECK: omp.yield(%true : i1) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1): +!CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : i1 +!CHECK: omp.yield(%[[RES]] : i1) +!CHECK: } + +!CHECK-LABEL: func.func @_QPsimple_reduction( +!CHECK-SAME: %[[ARRAY:.*]]: !fir.ref>> {fir.bindc_name = "y"}) { +!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_reductionEi"} +!CHECK: %[[XREF:.*]] = fir.alloca !fir.logical<4> {bindc_name = "x", uniq_name = "_QFsimple_reductionEx"} +!CHECK: omp.parallel +!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned} +!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32 +!CHECK: %[[C100:.*]] = arith.constant 100 : i32 +!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32 +!CHECK: omp.wsloop reduction(@[[RED_NAME]] -> %[[XREF]] : !fir.ref>) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C100]]) inclusive step (%[[C1_2]]) { +!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[CONVI_64:.*]] = fir.convert %[[I_PVT_VAL]] : (i32) -> i64 +!CHECK: %[[C1_64:.*]] = arith.constant 1 : i64 +!CHECK: %[[SUBI:.*]] = arith.subi %[[CONVI_64]], %[[C1_64]] : i64 +!CHECK: %[[Y_PVT_REF:.*]] = fir.coordinate_of %[[ARRAY]], %[[SUBI]] : (!fir.ref>>, i64) -> !fir.ref> +!CHECK: %[[YVAL:.*]] = fir.load %[[Y_PVT_REF]] : !fir.ref> +!CHECK: omp.reduction %[[YVAL]], %[[XREF]] : !fir.ref> +!CHECK: omp.yield +!CHECK: omp.terminator +!CHECK: return +subroutine simple_reduction(y) + logical :: x, y(100) + x = .true. + !$omp parallel + !$omp do reduction(.and.:x) + do i=1, 100 + x = x .and. y(i) + end do + !$omp end do + !$omp end parallel +end subroutine + +!CHECK-LABEL: func.func @_QPsimple_reduction_switch_order( +!CHECK-SAME: %[[ARRAY:.*]]: !fir.ref>> {fir.bindc_name = "y"}) { +!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_reduction_switch_orderEi"} +!CHECK: %[[XREF:.*]] = fir.alloca !fir.logical<4> {bindc_name = "x", uniq_name = "_QFsimple_reduction_switch_orderEx"} +!CHECK: omp.parallel +!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned} +!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32 +!CHECK: %[[C100:.*]] = arith.constant 100 : i32 +!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32 +!CHECK: omp.wsloop reduction(@[[RED_NAME]] -> %[[XREF]] : !fir.ref>) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C100]]) inclusive step (%[[C1_2]]) { +!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[CONVI_64:.*]] = fir.convert %[[I_PVT_VAL]] : (i32) -> i64 +!CHECK: %[[C1_64:.*]] = arith.constant 1 : i64 +!CHECK: %[[SUBI:.*]] = arith.subi %[[CONVI_64]], %[[C1_64]] : i64 +!CHECK: %[[Y_PVT_REF:.*]] = fir.coordinate_of %[[ARRAY]], %[[SUBI]] : (!fir.ref>>, i64) -> !fir.ref> +!CHECK: %[[YVAL:.*]] = fir.load %[[Y_PVT_REF]] : !fir.ref> +!CHECK: omp.reduction %[[YVAL]], %[[XREF]] : !fir.ref> +!CHECK: omp.yield +!CHECK: omp.terminator +!CHECK: return +subroutine simple_reduction_switch_order(y) + logical :: x, y(100) + x = .true. + !$omp parallel + !$omp do reduction(.and.:x) + do i=1, 100 + x = y(i) .and. x + end do + !$omp end do + !$omp end parallel +end subroutine + +!CHECK-LABEL: func.func @_QPmultiple_reductions +!CHECK-SAME %[[ARRAY:.*]]: !fir.ref>> {fir.bindc_name = "w"}) { +!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFmultiple_reductionsEi"} +!CHECK: %[[XREF:.*]] = fir.alloca !fir.logical<4> {bindc_name = "x", uniq_name = "_QFmultiple_reductionsEx"} +!CHECK: %[[YREF:.*]] = fir.alloca !fir.logical<4> {bindc_name = "y", uniq_name = "_QFmultiple_reductionsEy"} +!CHECK: %[[ZREF:.*]] = fir.alloca !fir.logical<4> {bindc_name = "z", uniq_name = "_QFmultiple_reductionsEz"} +!CHECK: omp.parallel +!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned} +!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32 +!CHECK: %[[C100:.*]] = arith.constant 100 : i32 +!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32 +!CHECK: omp.wsloop reduction(@[[RED_NAME]] -> %[[XREF]] : !fir.ref>, @[[RED_NAME]] -> %[[YREF]] : !fir.ref>, @[[RED_NAME]] -> %[[ZREF]] : !fir.ref>) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C100]]) inclusive step (%[[C1_2]]) { +!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[I_PVT_VAL1:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[CONVI_64_1:.*]] = fir.convert %[[I_PVT_VAL1]] : (i32) -> i64 +!CHECK: %[[C1_64:.*]] = arith.constant 1 : i64 +!CHECK: %[[SUBI_1:.*]] = arith.subi %[[CONVI_64_1]], %[[C1_64]] : i64 +!CHECK: %[[W_PVT_REF_1:.*]] = fir.coordinate_of %[[ARRAY]], %[[SUBI_1]] : (!fir.ref>>, i64) -> !fir.ref> +!CHECK: %[[WVAL:.*]] = fir.load %[[W_PVT_REF_1]] : !fir.ref> +!CHECK: omp.reduction %[[WVAL]], %[[XREF]] : !fir.ref> +!CHECK: %[[I_PVT_VAL2:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[CONVI_64_2:.*]] = fir.convert %[[I_PVT_VAL2]] : (i32) -> i64 +!CHECK: %[[C1_64:.*]] = arith.constant 1 : i64 +!CHECK: %[[SUBI_2:.*]] = arith.subi %[[CONVI_64_2]], %[[C1_64]] : i64 +!CHECK: %[[W_PVT_REF_2:.*]] = fir.coordinate_of %[[ARRAY]], %[[SUBI_2]] : (!fir.ref>>, i64) -> !fir.ref> +!CHECK: %[[WVAL:.*]] = fir.load %[[W_PVT_REF_2]] : !fir.ref> +!CHECK: omp.reduction %[[WVAL]], %[[YREF]] : !fir.ref> +!CHECK: %[[I_PVT_VAL3:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[CONVI_64_3:.*]] = fir.convert %[[I_PVT_VAL3]] : (i32) -> i64 +!CHECK: %[[C1_64:.*]] = arith.constant 1 : i64 +!CHECK: %[[SUBI_3:.*]] = arith.subi %[[CONVI_64_3]], %[[C1_64]] : i64 +!CHECK: %[[W_PVT_REF_3:.*]] = fir.coordinate_of %[[ARRAY]], %[[SUBI_3]] : (!fir.ref>>, i64) -> !fir.ref> +!CHECK: %[[WVAL:.*]] = fir.load %[[W_PVT_REF_3]] : !fir.ref> +!CHECK: omp.reduction %[[WVAL]], %[[ZREF]] : !fir.ref> +!CHECK: omp.yield +!CHECK: omp.terminator +!CHECK: return +subroutine multiple_reductions(w) + logical :: x,y,z,w(100) + x = .true. + y = .true. + z = .true. + !$omp parallel + !$omp do reduction(.and.:x,y,z) + do i=1, 100 + x = x .and. w(i) + y = y .and. w(i) + z = z .and. w(i) + end do + !$omp end do + !$omp end parallel +end subroutine +