// Patch overview (free-form text placed ahead of the first `diff --git`
// header, so it is not part of any hunk).
//
// * flang/include/flang/Lower/OpenACC.h: includes the MLIR OpenACC dialect
//   header, drops the forward declaration of mlir::acc::PrivateRecipeOp, and
//   declares createOrGetReductionRecipe in namespace Fortran::lower.
// * flang/lib/Lower/OpenACC.cpp: removes its own include of the MLIR OpenACC
//   dialect header and adds getReductionOperator (parser operator ->
//   mlir::acc::ReductionOperator), genReductionInitValue, genCombiner,
//   createOrGetReductionRecipe and genReductions. genReductions is called
//   where the reduction clause previously hit a TODO, and the collected
//   recipes are set on loopOp when the list is non-empty.
//   Only AccAdd is implemented: the init region yields a constant 0 and the
//   combiner region yields the sum of its two block arguments. Any other
//   operator, or a type matched by neither type check, reaches a TODO.
//   createOrGetReductionRecipe looks an existing recipe up by symbol name
//   before creating one, and restores the caller's insertion point after
//   building the init and combiner regions.
// * flang/test/Lower/OpenACC/acc-reduction.f90: new FileCheck test covering
//   the reduction_add_f32 and reduction_add_i32 recipes and the acc.loop that
//   references them.
// * mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp: checkSymOperandList gains a
//   `checkOperandType` parameter (default true); the reduction check passes
//   false, which skips the operand-vs-declaration type comparison.
//
// NOTE(review): this copy has its line breaks collapsed, and template argument
// lists appear to be missing (e.g. `builder.create(`, `mlir::isa(ty)`,
// `std::get(objectList.t)`, `!fir.ref>`), so it presumably cannot be applied
// as-is -- confirm against the upstream patch before using it.
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -13,6 +13,8 @@ #ifndef FORTRAN_LOWER_OPENACC_H #define FORTRAN_LOWER_OPENACC_H +#include "mlir/Dialect/OpenACC/OpenACC.h" + namespace llvm { class StringRef; } @@ -21,9 +23,6 @@ class Location; class Type; class OpBuilder; -namespace acc { -class PrivateRecipeOp; -} } // namespace mlir namespace Fortran { @@ -57,6 +56,12 @@ llvm::StringRef, mlir::Location, mlir::Type); +/// Get an acc.reduction.recipe op for the given type or create it if it does +/// not exist yet. +mlir::acc::ReductionRecipeOp +createOrGetReductionRecipe(mlir::OpBuilder &, llvm::StringRef, mlir::Location, + mlir::Type, mlir::acc::ReductionOperator); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -22,7 +22,6 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/tools.h" -#include "mlir/Dialect/OpenACC/OpenACC.h" #include "llvm/Frontend/OpenACC/ACC.h.inc" // Special value for * passed in device_type or gang clauses. @@ -526,6 +525,132 @@ } } +/// Return the corresponding enum value for the mlir::acc::ReductionOperator +/// from the parser representation. 
+static mlir::acc::ReductionOperator +getReductionOperator(const Fortran::parser::AccReductionOperator &op) { + switch (op.v) { + case Fortran::parser::AccReductionOperator::Operator::Plus: + return mlir::acc::ReductionOperator::AccAdd; + case Fortran::parser::AccReductionOperator::Operator::Multiply: + return mlir::acc::ReductionOperator::AccMul; + case Fortran::parser::AccReductionOperator::Operator::Max: + return mlir::acc::ReductionOperator::AccMax; + case Fortran::parser::AccReductionOperator::Operator::Min: + return mlir::acc::ReductionOperator::AccMin; + case Fortran::parser::AccReductionOperator::Operator::Iand: + return mlir::acc::ReductionOperator::AccIand; + case Fortran::parser::AccReductionOperator::Operator::Ior: + return mlir::acc::ReductionOperator::AccIor; + case Fortran::parser::AccReductionOperator::Operator::Ieor: + return mlir::acc::ReductionOperator::AccXor; + case Fortran::parser::AccReductionOperator::Operator::And: + return mlir::acc::ReductionOperator::AccLand; + case Fortran::parser::AccReductionOperator::Operator::Or: + return mlir::acc::ReductionOperator::AccLor; + case Fortran::parser::AccReductionOperator::Operator::Eqv: + return mlir::acc::ReductionOperator::AccEqv; + case Fortran::parser::AccReductionOperator::Operator::Neqv: + return mlir::acc::ReductionOperator::AccNeqv; + } + llvm_unreachable("unexpected reduction operator"); +} + +static mlir::Value genReductionInitValue(mlir::OpBuilder &builder, + mlir::Location loc, mlir::Type ty, + mlir::acc::ReductionOperator op) { + if (op != mlir::acc::ReductionOperator::AccAdd) + TODO(loc, "reduction operator"); + + unsigned initValue = 0; + + if (ty.isIntOrIndex()) + return builder.create( + loc, ty, builder.getIntegerAttr(ty, initValue)); + if (mlir::isa(ty)) + return builder.create( + loc, ty, builder.getFloatAttr(ty, initValue)); + TODO(loc, "reduction type"); +} + +static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc, + mlir::acc::ReductionOperator op, 
mlir::Type ty, + mlir::Value value1, mlir::Value value2) { + if (op == mlir::acc::ReductionOperator::AccAdd) { + if (ty.isIntOrIndex()) + return builder.create(loc, value1, value2); + if (mlir::isa(ty)) + return builder.create(loc, value1, value2); + TODO(loc, "reduction add type"); + } + TODO(loc, "reduction operator"); +} + +mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe( + mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, + mlir::Type ty, mlir::acc::ReductionOperator op) { + mlir::ModuleOp mod = + builder.getBlock()->getParent()->getParentOfType(); + if (auto recipe = mod.lookupSymbol(recipeName)) + return recipe; + + auto crtPos = builder.saveInsertionPoint(); + mlir::OpBuilder modBuilder(mod.getBodyRegion()); + auto recipe = + modBuilder.create(loc, recipeName, ty, op); + builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(), + {ty}, {loc}); + builder.setInsertionPointToEnd(&recipe.getInitRegion().back()); + mlir::Value initValue = genReductionInitValue(builder, loc, ty, op); + builder.create(loc, initValue); + + builder.createBlock(&recipe.getCombinerRegion(), + recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc}); + builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); + mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0); + mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1); + mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2); + builder.create(loc, combinedValue); + builder.restoreInsertionPoint(crtPos); + return recipe; +} + +static void +genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &reductionOperands, + llvm::SmallVector &reductionRecipes) { + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + const auto 
&objects = std::get(objectList.t); + const auto &op = + std::get(objectList.t); + mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); + for (const auto &accObject : objects.v) { + llvm::SmallVector bounds; + std::stringstream asFortran; + mlir::Location operandLocation = genOperandLocation(converter, accObject); + mlir::Value baseAddr = gatherDataOperandAddrAndBounds( + converter, builder, semanticsContext, stmtCtx, accObject, + operandLocation, asFortran, bounds); + + if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType()))) + TODO(operandLocation, "reduction with unsupported type"); + + mlir::Type ty = fir::unwrapRefType(baseAddr.getType()); + std::string recipeName = fir::getTypeAsString( + ty, converter.getKindMap(), + ("reduction_" + stringifyReductionOperator(mlirOp)).str()); + mlir::acc::ReductionRecipeOp recipe = + Fortran::lower::createOrGetReductionRecipe(builder, recipeName, + operandLocation, ty, mlirOp); + reductionRecipes.push_back(mlir::SymbolRefAttr::get( + builder.getContext(), recipe.getSymName().str())); + reductionOperands.push_back(baseAddr); + } +} + static void addOperands(llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &operandSegments, @@ -666,7 +791,7 @@ mlir::Value gangStatic; llvm::SmallVector tileOperands, privateOperands, reductionOperands; - llvm::SmallVector privatizations; + llvm::SmallVector privatizations, reductionRecipes; bool hasGang = false, hasVector = false, hasWorker = false; for (const Fortran::parser::AccClause &clause : accClauseList.v) { @@ -735,10 +860,11 @@ &clause.u)) { genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx, privateOperands, privatizations); - } else if (std::get_if(&clause.u)) { - // Reduction clause is left out for the moment as the clause will probably - // end up having its own operation. 
- TODO(clauseLocation, "OpenACC compute construct reduction lowering"); + } else if (const auto *reductionClause = + std::get_if( + &clause.u)) { + genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, + reductionOperands, reductionRecipes); } } @@ -767,6 +893,10 @@ loopOp.setPrivatizationsAttr( mlir::ArrayAttr::get(builder.getContext(), privatizations)); + if (!reductionRecipes.empty()) + loopOp.setReductionRecipesAttr( + mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); + // Lower clauses mapped to attributes for (const Fortran::parser::AccClause &clause : accClauseList.v) { if (const auto *collapseClause = diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-reduction.f90 @@ -0,0 +1,51 @@ +! This test checks lowering of OpenACC reduction clause. + +! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s + +! CHECK-LABEL: acc.reduction.recipe @reduction_add_f32 : f32 reduction_operator init { +! CHECK: ^bb0(%{{.*}}: f32): +! CHECK: %[[INIT:.*]] = arith.constant 0.000000e+00 : f32 +! CHECK: acc.yield %[[INIT]] : f32 +! CHECK: } combiner { +! CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32): +! CHECK: %[[COMBINED:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}} : f32 +! CHECK: acc.yield %[[COMBINED]] : f32 +! CHECK: } + +! CHECK-LABEL: acc.reduction.recipe @reduction_add_i32 : i32 reduction_operator init { +! CHECK: ^bb0(%{{.*}}: i32): +! CHECK: %[[INIT:.*]] = arith.constant 0 : i32 +! CHECK: acc.yield %[[INIT]] : i32 +! CHECK: } combiner { +! CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32): +! CHECK: %[[COMBINED:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 +! CHECK: acc.yield %[[COMBINED]] : i32 +! CHECK: } + +subroutine acc_reduction_add_int(a, b) + integer :: a(100) + integer :: i, b + + !$acc loop reduction(+:b) + do i = 1, 100 + b = b + a(i) + end do +end subroutine + +! 
CHECK-LABEL: func.func @_QPacc_reduction_add_int( +! CHECK-SAME: %{{.*}}: !fir.ref> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref {fir.bindc_name = "b"}) +! CHECK: acc.loop reduction(@reduction_add_i32 -> %[[B]] : !fir.ref) { + +subroutine acc_reduction_add_float(a, b) + real :: a(100), b + integer :: i + + !$acc loop reduction(+:b) + do i = 1, 100 + b = b + a(i) + end do +end subroutine + +! CHECK-LABEL: func.func @_QPacc_reduction_add_float( +! CHECK-SAME: %{{.*}}: !fir.ref> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref {fir.bindc_name = "b"}) +! CHECK: acc.loop reduction(@reduction_add_f32 -> %[[B]] : !fir.ref) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -498,7 +498,7 @@ static LogicalResult checkSymOperandList(Operation *op, std::optional attributes, mlir::OperandRange operands, llvm::StringRef operandName, - llvm::StringRef symbolName) { + llvm::StringRef symbolName, bool checkOperandType = true) { if (!operands.empty()) { if (!attributes || attributes->size() != operands.size()) return op->emitOpError() @@ -527,7 +527,7 @@ << "expected symbol reference " << symbolRef << " to point to a " << operandName << " declaration"; - if (decl.getType() && decl.getType() != varType) + if (checkOperandType && decl.getType() && decl.getType() != varType) return op->emitOpError() << "expected " << operandName << " (" << varType << ") to be the same type as " << operandName << " declaration (" << decl.getType() << ")"; @@ -751,7 +751,7 @@ if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions"))) + "reductions", false))) return failure(); // Check non-empty body().