diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -38,6 +38,8 @@ const parser::OpenMPDeclarativeConstruct &); int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList); void genThreadprivateOp(AbstractConverter &, const pft::Variable &); +void genOpenMPReduction(AbstractConverter &, + const Fortran::parser::OmpClauseList &clauseList); /* Replaces the load -> add -> store updates of each REDUCTION-clause variable in clauseList with omp.reduction ops. */ } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1630,11 +1630,12 @@ // no collapse requested. Fortran::lower::pft::Evaluation *curEval = &getEval(); + const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr; if (ompLoop) { - const auto &wsLoopOpClauseList = std::get( + loopOpClauseList = &std::get( std::get(ompLoop->t).t); int64_t collapseValue = - Fortran::lower::getCollapseValue(wsLoopOpClauseList); + Fortran::lower::getCollapseValue(*loopOpClauseList); curEval = &curEval->getFirstNestedEvaluation(); for (int64_t i = 1; i < collapseValue; i++) { @@ -1644,6 +1645,10 @@ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) genFIR(e); + + if (ompLoop) + genOpenMPReduction(*this, *loopOpClauseList); /* Placed after the nested evaluations are lowered: it pattern-matches the ops they generated. */ + localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); } diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -698,6 +698,38 @@ } } +/// Gets or creates the module-level `omp.reduction.declare` op named `name`. +/// A new declaration gets an init region yielding the constant 0 of `type` and +/// a combiner yielding `arith.addi` of its two block arguments, so only +/// integer types are supported. TODO: Generalize this, add atomic region. 
+static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder, + llvm::StringRef name, + mlir::Type type, + mlir::Location loc) { + OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); + mlir::OpBuilder modBuilder(module.getBodyRegion()); + auto decl = module.lookupSymbol(name); + if (!decl) + decl = modBuilder.create(loc, name, type); + + builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), + {type}, {loc}); + builder.setInsertionPointToEnd(&decl.initializerRegion().back()); + Value init = builder.create( + loc, type, builder.getIntegerAttr(type, 0)); + builder.create(loc, init); + + builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(), + {type, type}, {loc, loc}); + builder.setInsertionPointToEnd(&decl.reductionRegion().back()); + mlir::Value op1 = decl.reductionRegion().front().getArgument(0); + mlir::Value op2 = decl.reductionRegion().front().getArgument(1); + Value addRes = builder.create(loc, op1, op2); + builder.create(loc, addRes); + return decl; +} + static mlir::omp::ScheduleModifier translateModifier(const Fortran::parser::OmpScheduleModifierType &m) { switch (m.v) { @@ -773,6 +805,7 @@ mlir::Value scheduleChunkClauseOperand, ifClauseOperand; mlir::Attribute scheduleClauseOperand, noWaitClauseOperand, orderedClauseOperand, orderClauseOperand; + SmallVector reductionDeclSymbols; Fortran::lower::StatementContext stmtCtx; const auto &loopOpClauseList = std::get( std::get(loopConstruct.t).t); @@ -841,6 +874,45 @@ } else if (const auto &ifClause = std::get_if(&clause.u)) { ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); + } else if (const auto &reductionClause = + std::get_if( + &clause.u)) { + omp::ReductionDeclareOp decl; + const auto &redOperator{std::get( + reductionClause->v.t)}; + const auto &objectList{ + std::get(reductionClause->v.t)}; + if (const auto &redDefinedOp = + std::get_if(&redOperator.u)) { + for (const auto &ompObject : 
objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const auto *symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + reductionVars.push_back(symVal); + const auto &intrinsicOp{ + std::get( + redDefinedOp->u)}; + if (intrinsicOp == + Fortran::parser::DefinedOperator::IntrinsicOperator::Add) { + // TODO: Remove hardcoding of reduction names. + decl = createReductionDecl( + firOpBuilder, "add_reduction", + symVal.getType().cast().getEleTy(), + currentLocation); + } else { + TODO(currentLocation, + "Reduction of some intrinsic operators not supported"); + } + reductionDeclSymbols.push_back(SymbolRefAttr::get( + firOpBuilder.getContext(), decl.sym_name())); + } + } + } + } else { + TODO(currentLocation, + "OMPC_Reduction of intrinsic procedures not supported"); + } } } @@ -873,7 +945,11 @@ // 2. order auto wsLoopOp = firOpBuilder.create( currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars, - reductionVars, /*reductions=*/nullptr, + reductionVars, + reductionDeclSymbols.empty() + ? nullptr + : mlir::ArrayAttr::get(firOpBuilder.getContext(), + reductionDeclSymbols), scheduleClauseOperand.dyn_cast_or_null(), scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr, /*simd_modifier=*/nullptr, @@ -1410,3 +1486,68 @@ }, ompDeclConstruct.u); } + +// Find a chain load reduction var -> reduction_operation -> store reduction var +// and replace it with the reduction operation. +// TODO: Currently assumes it is an integer addition reduction. Generalize this +// for various reduction operation types. Add TODOs for unhandled cases. +// CHECK: Whether there are better approaches. Creating and removing operations +// does not look like a robust approach. Also, removing ops in the builder +// (instead of a rewriter) does not look great. 
+void Fortran::lower::genOpenMPReduction( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &clauseList) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + for (const auto &clause : clauseList.v) { + if (const auto &reductionClause = + std::get_if(&clause.u)) { + mlir::omp::ReductionDeclareOp decl; + const auto &redOperator{std::get( + reductionClause->v.t)}; + const auto &objectList{ + std::get(reductionClause->v.t)}; + if (std::get_if(&redOperator.u)) { + for (const auto &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const auto *symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + for (mlir::OpOperand &use1 : symVal.getUses()) { + if (auto load = mlir::dyn_cast(use1.getOwner())) { + mlir::Value loadVal = load.getRes(); + for (mlir::OpOperand &use2 : loadVal.getUses()) { + if (auto add = mlir::dyn_cast( + use2.getOwner())) { + mlir::Value addRes = add.getResult(); + for (mlir::OpOperand &use3 : addRes.getUses()) { + if (auto store = + mlir::dyn_cast(use3.getOwner())) { + if (store.getMemref() == symVal) { + mlir::OpBuilder::InsertPoint insertPtDel = + firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPoint(add); + if (add.getLhs() == loadVal) { + firOpBuilder.create( + add.getLoc(), add.getRhs(), symVal); + } else { + firOpBuilder.create( + add.getLoc(), add.getRhs(), symVal); + } + store.erase(); + add.erase(); + load.erase(); + firOpBuilder.restoreInsertionPoint(insertPtDel); + } + } + } + } + } + } + } + } + } + } + } + } + } +} diff --git a/flang/test/Lower/OpenMP/reduction.f90 b/flang/test/Lower/OpenMP/reduction.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenMP/reduction.f90 @@ -0,0 +1,44 @@ +! RUN: bbc -emit-fir -fopenmp %s -o - | FileCheck %s +! 
RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s + +!CHECK-LABEL: omp.reduction.declare +!CHECK-SAME: @[[RED_NAME:.*]] : i32 init { +!CHECK: ^bb0(%{{.*}}: i32): +!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[C0_1]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32): +!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 +!CHECK: omp.yield(%[[RES]] : i32) +!CHECK: } +!CHECK-LABEL: func.func @_QPsimple_reduction +!CHECK: %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_reductionEx"} +!CHECK: %[[C0_2:.*]] = arith.constant 0 : i32 +!CHECK: fir.store %[[C0_2]] to %[[XREF]] : !fir.ref +!CHECK: omp.parallel +!CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned} +!CHECK: %[[C1_1:.*]] = arith.constant 1 : i32 +!CHECK: %[[C100:.*]] = arith.constant 100 : i32 +!CHECK: %[[C1_2:.*]] = arith.constant 1 : i32 +!CHECK: omp.wsloop reduction(@[[RED_NAME]] -> %[[XREF]] : !fir.ref) for (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C100]]) inclusive step (%[[C1_2]]) +!CHECK: fir.store %[[IVAL]] to %[[I_PVT_REF]] : !fir.ref +!CHECK: %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_REF]] : !fir.ref +!CHECK: omp.reduction %[[I_PVT_VAL]], %[[XREF]] : !fir.ref +!CHECK: omp.yield +!CHECK: omp.terminator +!CHECK: %[[XVAL_FINAL:.*]] = fir.load %[[XREF]] : !fir.ref +!CHECK: {{.*}} = fir.call @_FortranAioOutputInteger32({{.*}}, %[[XVAL_FINAL]]) : (!fir.ref, i32) -> i1 +!CHECK: return + +subroutine simple_reduction + integer :: x + x = 0 + !$omp parallel + !$omp do reduction(+:x) + do i=1, 100 + x = x + i + end do + !$omp end do + !$omp end parallel + print *, x +end subroutine diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1308,7 +1308,8 @@ // 2.19.5.7 declare reduction Directive 
//===----------------------------------------------------------------------===// -def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol]> { +def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol, + IsolatedFromAbove]> { /* IsolatedFromAbove: the init and combiner region bodies cannot reference SSA values defined outside this op. */ let summary = "declares a reduction kind"; let description = [{ diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -77,6 +77,21 @@ return success(); } }; + +struct ReductionOpConversion : public ConvertOpToLLVMPattern { /* Re-creates omp.reduction with the adaptor's (type-converted) operands and the original attributes; memref accumulators are rejected with a match failure. */ + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (curOp.accumulator().getType().isa()) { + // TODO: Support memref type in variable operands + return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); + } + rewriter.replaceOpWithNewOp( + curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs()); + return success(); + } +}; } // namespace void mlir::configureOpenMPToLLVMConversionLegality( @@ -95,14 +110,17 @@ return typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); + target.addDynamicallyLegalOp([&](Operation *op) { /* omp.reduction is legal once every operand type is legal for the LLVM type converter. */ + return typeConverter.isLegal(op->getOperandTypes()); + }); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add< - RegionOpConversion, RegionOpConversion, - RegionOpConversion, RegionOpConversion, - RegionOpConversion, + ReductionOpConversion, RegionOpConversion, + RegionOpConversion, RegionOpConversion, + RegionOpConversion, RegionOpConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion,