diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -15,6 +15,15 @@ #include +namespace mlir { +class Value; +class Operation; +} // namespace mlir + +namespace fir { +class FirOpBuilder; +} // namespace fir + namespace Fortran { namespace parser { struct OpenMPConstruct; @@ -41,6 +50,11 @@ void genOpenMPReduction(AbstractConverter &, const Fortran::parser::OmpClauseList &clauseList); +void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value, + mlir::Value); + +mlir::Operation *getReductionInChain(mlir::Value, mlir::Value); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -1633,42 +1633,19 @@ if (const auto *name{ Fortran::parser::Unwrap(ompObject)}) { if (const auto *symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - if (!redType.isIntOrIndex()) + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + mlir::Type reductionType = + reductionVal.getType().cast().getEleTy(); + if (!reductionType.isIntOrIndex()) continue; - for (mlir::OpOperand &use1 : symVal.getUses()) { - if (auto load = mlir::dyn_cast(use1.getOwner())) { - mlir::Value loadVal = load.getRes(); - for (mlir::OpOperand &use2 : loadVal.getUses()) { - if (auto add = mlir::dyn_cast( - use2.getOwner())) { - mlir::Value addRes = add.getResult(); - for (mlir::OpOperand &use3 : addRes.getUses()) { - if (auto store = - mlir::dyn_cast(use3.getOwner())) { - if (store.getMemref() == symVal) { - // Chain found! Now replace load->reduction->store - // with the OpenMP reduction operation.
- mlir::OpBuilder::InsertPoint insertPtDel = - firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPoint(add); - if (add.getLhs() == loadVal) { - firOpBuilder.create( - add.getLoc(), add.getRhs(), symVal); - } else { - firOpBuilder.create( - add.getLoc(), add.getLhs(), symVal); - } - store.erase(); - add.erase(); - load.erase(); - firOpBuilder.restoreInsertionPoint(insertPtDel); - } - } - } - } + + for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { + + if (auto loadOp = + mlir::dyn_cast(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + if (auto reductionOp = getReductionInChain(reductionVal, loadVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal); } } } @@ -1679,3 +1656,42 @@ } } } + +// Checks whether loadVal is used by an operation whose result is then +// stored into reductionVal. If so, that store is erased and the operation +// (the reduction update) is returned; otherwise nullptr is returned. +// loadVal is assumed to be the value loaded from reductionVal, and +// reductionVal is the address of the reduction-clause variable's symbol.
+mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value reductionVal,
+                                                     mlir::Value loadVal) {
+  for (mlir::OpOperand &loadUse : loadVal.getUses()) {
+    mlir::Operation *reductionOp = loadUse.getOwner();
+    // updateReduction() forwards "the operand that is not loadVal", so only
+    // binary, single-result operations can be rewritten. Accepting anything
+    // else (e.g. a unary intrinsic applied to the loaded value) would make it
+    // index a non-existent operand.
+    if (reductionOp->getNumOperands() != 2 ||
+        reductionOp->getNumResults() != 1)
+      continue;
+    // Dropping loadVal and contributing only the other operand is only sound
+    // when operand order is irrelevant: x = x - y must not be turned into a
+    // contribution of y. The pre-refactor code matched only integer adds.
+    if (!reductionOp->hasTrait<mlir::OpTrait::IsCommutative>())
+      continue;
+    for (mlir::OpOperand &resultUse : reductionOp->getResult(0).getUses()) {
+      auto store = mlir::dyn_cast<fir::StoreOp>(resultUse.getOwner());
+      if (store && store.getMemref() == reductionVal) {
+        // The store is superseded by the omp.reduction the caller emits via
+        // updateReduction(). Returning right after the erase means this loop
+        // never advances past the removed use.
+        store.erase();
+        return reductionOp;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Emits an omp.reduction immediately before `op` (an update returned by
+// getReductionInChain), contributing the operand of `op` that is not loadVal
+// to the reduction variable at reductionVal. `op` must have two operands, one
+// of which is loadVal. `op` itself is left in place, and the builder's
+// insertion point is restored before returning.
+void Fortran::lower::updateReduction(mlir::Operation *op,
+                                     fir::FirOpBuilder &firOpBuilder,
+                                     mlir::Value loadVal,
+                                     mlir::Value reductionVal) {
+  auto savedInsertPt = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPoint(op);
+
+  mlir::Value contribution =
+      op->getOperand(0) == loadVal ? op->getOperand(1) : op->getOperand(0);
+  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), contribution,
+                                              reductionVal);
+
+  firOpBuilder.restoreInsertionPoint(savedInsertPt);
+}