diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -851,7 +851,7 @@ if (reductionOpName.contains("add")) return 0; else if (reductionOpName.contains("multiply") || - reductionOpName.contains("and")) + reductionOpName.contains("and") || reductionOpName.contains("eqv")) return 1; TODO(loc, "Reduction of some intrinsic operators is not supported"); } @@ -859,14 +859,19 @@ static Value getReductionInitValue(mlir::Location loc, mlir::Type type, llvm::StringRef reductionOpName, fir::FirOpBuilder &builder) { - assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); if (type.isa()) return builder.create( loc, type, builder.getFloatAttr( type, (double)getOperationIdentity(reductionOpName, loc))); + if (type.isa()) { + Value intConst = builder.create( + loc, builder.getI1Type(), + builder.getIntegerAttr(builder.getI1Type(), + getOperationIdentity(reductionOpName, loc))); + return builder.createConvert(loc, type, intConst); + } return builder.create( loc, type, builder.getIntegerAttr(type, getOperationIdentity(reductionOpName, loc))); @@ -925,9 +930,25 @@ getReductionOperation( builder, type, loc, op1, op2); break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - reductionOp = builder.create(loc, op1, op2); + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { + Value op1_i1 = builder.createConvert(loc, builder.getI1Type(), op1); + Value op2_i1 = builder.createConvert(loc, builder.getI1Type(), op2); + + Value andiOp = builder.create(loc, op1_i1, op2_i1); + + reductionOp = builder.createConvert(loc, type, andiOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { + Value op1_i1 = builder.createConvert(loc, builder.getI1Type(), op1); + Value op2_i1 = builder.createConvert(loc, builder.getI1Type(), op2); + + Value cmpiOp = builder.create( + loc, arith::CmpIPredicate::eq, op1_i1, op2_i1); + + reductionOp = builder.createConvert(loc, type, cmpiOp); break; + } default: TODO(loc, "Reduction of some intrinsic operators is not supported"); } @@ -1014,6 +1035,8 @@ break; case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return "eqv_reduction"; default: reductionName = "other_reduction"; break; @@ -1123,6 +1146,7 @@ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: break; default: @@ -1139,8 +1163,11 @@ symVal.getType().cast().getEleTy(); reductionVars.push_back(symVal); if (redType.isa()) - redType = firOpBuilder.getI1Type(); - if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), + intrinsicOp, redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { decl = createReductionDecl( firOpBuilder, getReductionName(intrinsicOp, redType), intrinsicOp, redType, currentLocation); @@ -1824,6 +1851,7 @@ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: break; default: continue; @@ -1835,9 +1863,7 @@ mlir::Value reductionVal = converter.getSymbolAddress(*symbol); mlir::Type reductionType = reductionVal.getType().cast().getEleTy(); - - if (intrinsicOp != - Fortran::parser::DefinedOperator::IntrinsicOperator::AND) { + if (!reductionType.isa()) { if (!reductionType.isIntOrIndexOrFloat()) continue; } @@ -1845,8 +1871,7 @@ if (auto loadOp = mlir::dyn_cast( reductionValUse.getOwner())) { mlir::Value loadVal = loadOp.getRes(); - if (intrinsicOp == Fortran::parser::DefinedOperator:: - IntrinsicOperator::AND) { + if (reductionType.isa()) { mlir::Operation *reductionOp = findReductionChain(loadVal); fir::ConvertOp convertOp = getConvertFromReductionOp(reductionOp, loadVal); diff --git a/flang/test/Lower/OpenMP/Todo/reduction-eqv.f90 b/flang/test/Lower/OpenMP/Todo/reduction-eqv.f90 deleted file mode 100644 --- a/flang/test/Lower/OpenMP/Todo/reduction-eqv.f90 +++ /dev/null @@ -1,15 +0,0 @@ -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s - -! CHECK: not yet implemented: Reduction of some intrinsic operators is not supported -subroutine reduction_eqv(y) - logical :: x, y(100) - !$omp parallel - !$omp do reduction(.eqv.:x) - do i=1, 100 - x = x .eqv. y(i) - end do - !$omp end do - !$omp end parallel - print *, x -end subroutine diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 --- a/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 +++ b/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 @@ -2,14 +2,18 @@ ! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s !CHECK-LABEL: omp.reduction.declare -!CHECK-SAME: @[[RED_NAME:.*]] : i1 init { -!CHECK: ^bb0(%{{.*}}: i1): +!CHECK-SAME: @[[RED_NAME:.*]] : !fir.logical<4> init { +!CHECK: ^bb0(%{{.*}}: !fir.logical<4>): !CHECK: %true = arith.constant true -!CHECK: omp.yield(%true : i1) +!CHECK: %[[true_fir:.*]] = fir.convert %true : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[true_fir]] : !fir.logical<4>) !CHECK: } combiner { -!CHECK: ^bb0(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1): -!CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : i1 -!CHECK: omp.yield(%[[RES]] : i1) +!CHECK: ^bb0(%[[ARG0:.*]]: !fir.logical<4>, %[[ARG1:.*]]: !fir.logical<4>): +!CHECK: %[[arg0_i1:.*]] = fir.convert %[[ARG0]] : (!fir.logical<4>) -> i1 +!CHECK: %[[arg1_i1:.*]] = fir.convert %[[ARG1]] : (!fir.logical<4>) -> i1 +!CHECK: %[[RES:.*]] = arith.andi %[[arg0_i1]], %[[arg1_i1]] : i1 +!CHECK: %[[RES_logical:.*]] = fir.convert %[[RES]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RES_logical]] : !fir.logical<4>) !CHECK: } !CHECK-LABEL: func.func @_QPsimple_reduction( diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 copy from flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 copy to flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 --- a/flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 +++ b/flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 @@ -2,14 +2,18 @@ ! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s !CHECK-LABEL: omp.reduction.declare -!CHECK-SAME: @[[RED_NAME:.*]] : i1 init { -!CHECK: ^bb0(%{{.*}}: i1): +!CHECK-SAME: @[[RED_NAME:.*]] : !fir.logical<4> init { +!CHECK: ^bb0(%{{.*}}: !fir.logical<4>): !CHECK: %true = arith.constant true -!CHECK: omp.yield(%true : i1) +!CHECK: %[[true_fir:.*]] = fir.convert %true : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[true_fir]] : !fir.logical<4>) !CHECK: } combiner { -!CHECK: ^bb0(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1): -!CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : i1 -!CHECK: omp.yield(%[[RES]] : i1) +!CHECK: ^bb0(%[[ARG0:.*]]: !fir.logical<4>, %[[ARG1:.*]]: !fir.logical<4>): +!CHECK: %[[arg0_i1:.*]] = fir.convert %[[ARG0]] : (!fir.logical<4>) -> i1 +!CHECK: %[[arg1_i1:.*]] = fir.convert %[[ARG1]] : (!fir.logical<4>) -> i1 +!CHECK: %[[RES:.*]] = arith.cmpi eq, %[[arg0_i1]], %[[arg1_i1]] : i1 +!CHECK: %[[RES_logical:.*]] = fir.convert %[[RES]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RES_logical]] : !fir.logical<4>) !CHECK: } !CHECK-LABEL: func.func @_QPsimple_reduction( @@ -37,9 +41,9 @@ logical :: x, y(100) x = .true. !$omp parallel - !$omp do reduction(.and.:x) + !$omp do reduction(.eqv.:x) do i=1, 100 - x = x .and. y(i) + x = x .eqv. y(i) end do !$omp end do !$omp end parallel @@ -70,9 +74,9 @@ logical :: x, y(100) x = .true. !$omp parallel - !$omp do reduction(.and.:x) + !$omp do reduction(.eqv.:x) do i=1, 100 - x = y(i) .and. x + x = y(i) .eqv. x end do !$omp end do !$omp end parallel @@ -121,13 +125,12 @@ y = .true. z = .true. !$omp parallel - !$omp do reduction(.and.:x,y,z) + !$omp do reduction(.eqv.:x,y,z) do i=1, 100 - x = x .and. w(i) - y = y .and. w(i) - z = z .and. w(i) + x = x .eqv. w(i) + y = y .eqv. w(i) + z = z .eqv. w(i) end do !$omp end do !$omp end parallel end subroutine - diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -83,6 +83,23 @@ } }; +template +struct RegionLessOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(T curOp, typename T::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + SmallVector resTypes; + if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) + return failure(); + + rewriter.replaceOpWithNewOp(curOp, resTypes, adaptor.getOperands(), + curOp->getAttrs()); + return success(); + } +}; + struct ReductionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -97,6 +114,29 @@ return success(); } }; + +struct ReductionDeclareOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(omp::ReductionDeclareOp curOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.create( + curOp.getLoc(), TypeRange(), curOp.sym_nameAttr(), + TypeAttr::get(this->getTypeConverter()->convertType( + curOp.typeAttr().getValue()))); + for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) { + rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx), + newOp.getRegion(idx).end()); + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx), + *this->getTypeConverter()))) + return failure(); + } + + rewriter.eraseOp(curOp); + return success(); + } +}; } // namespace void mlir::configureOpenMPToLLVMConversionLegality( @@ -109,30 +149,39 @@ typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); - target - .addDynamicallyLegalOp( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target.addDynamicallyLegalOp([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return typeConverter.isLegal(&op->getRegion(0)) && + typeConverter.isLegal(&op->getRegion(1)) && + typeConverter.isLegal(&op->getRegion(2)) && + typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add< - ReductionOpConversion, RegionOpConversion, - RegionOpConversion, ReductionOpConversion, + ReductionOpConversion, ReductionDeclareOpConversion, + RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, - RegionLessOpWithVarOperandsConversion>(converter); + RegionLessOpWithVarOperandsConversion, + RegionLessOpConversion>(converter); } namespace {