diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -15,6 +15,15 @@ #include +namespace mlir { +class Value; +class Operation; +} // namespace mlir + +namespace fir { +class FirOpBuilder; +} // namespace fir + namespace Fortran { namespace parser { struct OpenMPConstruct; @@ -41,6 +50,11 @@ void genOpenMPReduction(AbstractConverter &, const Fortran::parser::OmpClauseList &clauseList); +void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, + mlir::Value); + +mlir::Operation *getReductionInChain(mlir::Value, mlir::Value); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -1638,37 +1638,14 @@ symVal.getType().cast().getEleTy(); if (!redType.isIntOrIndex()) continue; + for (mlir::OpOperand &use1 : symVal.getUses()) { - if (auto load = mlir::dyn_cast(use1.getOwner())) { - mlir::Value loadVal = load.getRes(); - for (mlir::OpOperand &use2 : loadVal.getUses()) { - if (auto add = mlir::dyn_cast( - use2.getOwner())) { - mlir::Value addRes = add.getResult(); - for (mlir::OpOperand &use3 : addRes.getUses()) { - if (auto store = - mlir::dyn_cast(use3.getOwner())) { - if (store.getMemref() == symVal) { - // Chain found! Now replace load->reduction->store - // with the OpenMP reduction operation. - mlir::OpBuilder::InsertPoint insertPtDel = - firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPoint(add); - if (add.getLhs() == loadVal) { - firOpBuilder.create( - add.getLoc(), add.getRhs(), symVal); - } else { - firOpBuilder.create( - add.getLoc(), add.getLhs(), symVal); - } - store.erase(); - add.erase(); - load.erase(); - firOpBuilder.restoreInsertionPoint(insertPtDel); - } - } - } - } + + if (auto loadOp = + mlir::dyn_cast(use1.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + if (auto reductionOp = getReductionInChain(symVal, loadVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, symVal); } } } @@ -1679,3 +1656,37 @@ } } } + +mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value symVal, + mlir::Value loadVal) { + for (mlir::OpOperand &loadOperands : loadVal.getUses()) { + if (auto reductionOp = loadOperands.getOwner()) { + for (mlir::OpOperand &reductionOperands : reductionOp->getUses()) { + if (auto store = + mlir::dyn_cast(reductionOperands.getOwner())) { + if (store.getMemref() == symVal) { + store.erase(); + return reductionOp; + } + } + } + } + } + return nullptr; +} + +void Fortran::lower::updateReduction(mlir::Operation *op, + fir::FirOpBuilder &firOpBuilder, + mlir::Value loadVal, mlir::Value symVal) { + mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPoint(op); + + if (op->getOperand(0) == loadVal) + firOpBuilder.create(op->getLoc(), op->getOperand(1), + symVal); + else + firOpBuilder.create(op->getLoc(), op->getOperand(0), + symVal); + + firOpBuilder.restoreInsertionPoint(insertPtDel); +}