diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -168,6 +168,30 @@ using HlfirIntrinsicConversion::HlfirIntrinsicConversion; using IntrinsicArgument = typename HlfirIntrinsicConversion::IntrinsicArgument; + using HlfirIntrinsicConversion::lowerArguments; + using HlfirIntrinsicConversion::processReturnValue; + +protected: + auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, + mlir::PatternRewriter &rewriter, + std::string opName) const { + llvm::SmallVector inArgs; + inArgs.push_back({operation.getArray(), operation.getArray().getType()}); + inArgs.push_back({operation.getDim(), i32}); + inArgs.push_back({operation.getMask(), logicalType}); + auto *argLowering = fir::getIntrinsicArgumentLowering(opName); + return lowerArguments(operation, inArgs, rewriter, argLowering); + }; + + auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, + mlir::PatternRewriter &rewriter, + std::string opName) const { + llvm::SmallVector inArgs; + inArgs.push_back({operation.getMask(), logicalType}); + inArgs.push_back({operation.getDim(), i32}); + auto *argLowering = fir::getIntrinsicArgumentLowering(opName); + return lowerArguments(operation, inArgs, rewriter, argLowering); + }; public: mlir::LogicalResult @@ -178,9 +202,14 @@ opName = "sum"; } else if constexpr (std::is_same_v) { opName = "product"; + } else if constexpr (std::is_same_v) { + opName = "any"; + } else if constexpr (std::is_same_v) { + opName = "all"; } else { return mlir::failure(); } + fir::KindMapping kindMapping{rewriter.getContext()}; fir::FirOpBuilder builder{rewriter, kindMapping}; const mlir::Location &loc = operation->getLoc(); @@ -188,14 +217,15 @@ mlir::Type i32 = builder.getI32Type(); mlir::Type logicalType = fir::LogicalType::get( builder.getContext(), builder.getKindMap().defaultLogicalKind()); - llvm::SmallVector inArgs; - inArgs.push_back({operation.getArray(), operation.getArray().getType()}); - inArgs.push_back({operation.getDim(), i32}); - inArgs.push_back({operation.getMask(), logicalType}); - auto *argLowering = fir::getIntrinsicArgumentLowering(opName); - llvm::SmallVector args = - this->lowerArguments(operation, inArgs, rewriter, argLowering); + llvm::SmallVector args; + + if constexpr (std::is_same_v || + std::is_same_v) { + args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName); + } else { + args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName); + } mlir::Type scalarResultType = hlfir::getFortranElementType(operation.getType()); @@ -203,8 +233,7 @@ auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args); - this->processReturnValue(operation, resultExv, mustBeFreed, builder, - rewriter); + processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; @@ -213,37 +242,9 @@ using ProductOpConversion = HlfirReductionIntrinsicConversion; -struct AnyOpConversion : public HlfirIntrinsicConversion { - using HlfirIntrinsicConversion::HlfirIntrinsicConversion; +using AnyOpConversion = HlfirReductionIntrinsicConversion; - mlir::LogicalResult - matchAndRewrite(hlfir::AnyOp any, - mlir::PatternRewriter &rewriter) const override { - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping}; - const mlir::Location &loc = any->getLoc(); - - mlir::Type i32 = builder.getI32Type(); - mlir::Type logicalType = fir::LogicalType::get( - builder.getContext(), builder.getKindMap().defaultLogicalKind()); - llvm::SmallVector inArgs; - inArgs.push_back({any.getMask(), logicalType}); - inArgs.push_back({any.getDim(), i32}); - - auto *argLowering = fir::getIntrinsicArgumentLowering("any"); - llvm::SmallVector args = - this->lowerArguments(any, inArgs, rewriter, argLowering); - - mlir::Type resultType = hlfir::getFortranElementType(any.getType()); - - auto [resultExv, mustBeFreed] = - fir::genIntrinsicCall(builder, loc, "any", resultType, args); - - this->processReturnValue(any, resultExv, mustBeFreed, builder, rewriter); - - return mlir::success(); - } -}; +using AllOpConversion = HlfirReductionIntrinsicConversion; struct MatmulOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; @@ -354,14 +355,15 @@ mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.insert(context); + AllOpConversion, AnyOpConversion, SumOpConversion, + ProductOpConversion, TransposeOpConversion>(context); mlir::ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalOp(); + hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp, + hlfir::AllOp>(); target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); if (mlir::failed( diff --git a/flang/test/HLFIR/all-lowering.fir b/flang/test/HLFIR/all-lowering.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/all-lowering.fir @@ -0,0 +1,157 @@ +// Test hlfir.all operation lowering to fir runtime call +// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s + +func.func @_QPall1(%arg0: !fir.box>> {fir.bindc_name = "a"}, %arg1: !fir.ref> {fir.bindc_name = "s"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFall1Ea"} : (!fir.box>>) -> (!fir.box>>, !fir.box>>) + %1:2 = hlfir.declare %arg1 {uniq_name = "_QFall1Es"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + %2 = hlfir.all %0#0 : (!fir.box>>) -> !fir.logical<4> + hlfir.assign %2 to %1#0 : !fir.logical<4>, !fir.ref> + return +} +// CHECK-LABEL: func.func @_QPall1( +// CHECK: %[[ARG0:.*]]: !fir.box>> {fir.bindc_name = "a"} +// CHECK: %[[ARG1:.*]]: !fir.ref> +// CHECK-DAG: %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 : (!fir.box>>) -> !fir.box +// CHECK: %[[RET_ARG:.*]] = fir.call @_FortranAAll(%[[MASK_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[C1:.*]]) : (!fir.box, !fir.ref, i32, i32) -> i1 +// CHECK-NEXT: %[[RET:.*]] = fir.convert %[[RET_ARG]] : (i1) -> !fir.logical<4> +// CHECK-NEXT: hlfir.assign %[[RET]] to %[[RES]]#0 : !fir.logical<4>, !fir.ref> +// CHECK-NEXT: return +// CHECK-NEXT: } + +func.func @_QPall2(%arg0: !fir.box>> {fir.bindc_name = "a"}, %arg1: !fir.box>> {fir.bindc_name = "s"}, %arg2: !fir.ref {fir.bindc_name = "d"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFall2Ea"} : (!fir.box>>) -> (!fir.box>>, !fir.box>>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFall2Ed"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFall2Es"} : (!fir.box>>) -> (!fir.box>>, !fir.box>>) + %3 = fir.load %1#0 : !fir.ref + %4 = hlfir.all %0#0 dim %3 : (!fir.box>>, i32) -> !hlfir.expr> + hlfir.assign %4 to %2#0 : !hlfir.expr>, !fir.box>> + hlfir.destroy %4 : !hlfir.expr> + return +} +// CHECK-LABEL: func.func @_QPall2( +// CHECK: %[[ARG0:.*]]: !fir.box>> +// CHECK: %[[ARG1:.*]]: !fir.box>> +// CHECK: %[[ARG2:.*]]: !fir.ref +// CHECK-DAG: %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] + +// CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>>> +// CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap>> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> +// CHECK-DAG: %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]]) +// CHECK-DAG: fir.store %[[RET_EMBOX]] to %[[RET_BOX]] + +// CHECK-DAG: %[[DIM:.*]] = fir.load %[[DIM_VAR]]#0 : !fir.ref +// CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 + +// CHECK: %[[NONE:.*]] = fir.call @_FortranAAllDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) : (!fir.ref>, !fir.box, i32, !fir.ref, i32) -> none +// CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] +// CHECK: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] +// CHECK-NEXT: %[[ADDR:.*]] = fir.box_addr %[[RET]] +// CHECK-NEXT: %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// CHECK-NEXT: %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"} +// CHECK: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box>>, i1) -> !hlfir.expr> +// CHECK: hlfir.assign %[[EXPR]] to %[[RES]]#0 +// CHECK: hlfir.destroy %[[EXPR]] +// CHECK-NEXT: return +// CHECK-NEXT: } + +func.func @_QPall3(%arg0: !fir.ref>> {fir.bindc_name = "s"}) { + %0 = fir.address_of(@_QFall3Ea) : !fir.ref>> + %c2 = arith.constant 2 : index + %c2_0 = arith.constant 2 : index + %1 = fir.shape %c2, %c2_0 : (index, index) -> !fir.shape<2> + %2:2 = hlfir.declare %0(%1) {uniq_name = "_QFall3Ea"} : (!fir.ref>>, !fir.shape<2>) -> (!fir.ref>>, !fir.ref>>) + %c2_1 = arith.constant 2 : index + %3 = fir.shape %c2_1 : (index) -> !fir.shape<1> + %4:2 = hlfir.declare %arg0(%3) {uniq_name = "_QFall3Es"} : (!fir.ref>>, !fir.shape<1>) -> (!fir.ref>>, !fir.ref>>) + %c1_i32 = arith.constant 1 : i32 + %5 = hlfir.all %2#0 dim %c1_i32 : (!fir.ref>>, i32) -> !hlfir.expr<2x!fir.logical<4>> + hlfir.assign %5 to %4#0 : !hlfir.expr<2x!fir.logical<4>>, !fir.ref>> + hlfir.destroy %5 : !hlfir.expr<2x!fir.logical<4>> + return +} +// CHECK-LABEL: func.func @_QPall3( +// CHECK: %[[ARG0:.*]]: !fir.ref>> +// CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>>> +// CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap>> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> +// CHECK-DAG: %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]]) +// CHECK-DAG: fir.store %[[RET_EMBOX]] to %[[RET_BOX]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG0]](%[[RES_SHAPE:.*]]) + +// CHECK-DAG: %[[MASK_ADDR:.*]] = fir.address_of +// CHECK-DAG: %[[MASK_VAR:.*]]:2 = hlfir.declare %[[MASK_ADDR]](%[[MASK_SHAPE:.*]]) +// CHECK-DAG: %[[MASK_BOX:.*]] = fir.embox %[[MASK_VAR]]#1(%[[MASK_SHAPE:.*]]) + +// CHECK-DAG: %[[DIM:.*]] = arith.constant 1 : i32 + +// CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box>>) -> !fir.box +// CHECK: %[[NONE:.*]] = fir.call @_FortranAAllDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) +// CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] +// CHECK: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] +// CHECK-NEXT: %[[ADDR:.*]] = fir.box_addr %[[RET]] +// CHECK-NEXT: %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// CHECK-NEXT: %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"} +// CHECK: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box>>, i1) -> !hlfir.expr> +// CHECK: hlfir.assign %[[EXPR]] to %[[RES]] +// CHECK: hlfir.destroy %[[EXPR]] +// CHECK-NEXT: return +// CHECK-NEXT: } + +func.func @_QPall4(%arg0: !fir.box>> {fir.bindc_name = "a"}, %arg1: !fir.box>> {fir.bindc_name = "s"}, %arg2: !fir.ref>> {fir.bindc_name = "d"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFall4Ea"} : (!fir.box>>) -> (!fir.box>>, !fir.box>>) + %1:2 = hlfir.declare %arg2 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFall4Ed"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFall4Es"} : (!fir.box>>) -> (!fir.box>>, !fir.box>>) + %3 = fir.load %1#0 : !fir.ref>> + %4 = fir.box_addr %3 : (!fir.box>) -> !fir.ptr + %5 = fir.load %4 : !fir.ptr + %6 = hlfir.no_reassoc %5 : i32 + %7 = hlfir.all %0#0 dim %6 : (!fir.box>>, i32) -> !hlfir.expr> + hlfir.assign %7 to %2#0 : !hlfir.expr>, !fir.box>> + hlfir.destroy %7 : !hlfir.expr> + return +} +// CHECK-LABEL: func.func @_QPall4( +// CHECK: %[[ARG0:.*]]: !fir.box>> +// CHECK: %[[ARG1:.*]]: !fir.box>> +// CHECK: %[[ARG2:.*]]: !fir.ref>> +// CHECK-DAG: %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[DIM_ARG:.*]]:2 = hlfir.declare %[[ARG2]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] + +// CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>>> +// CHECK-DAG: %[[DIM_PTR:.*]] = fir.load %[[DIM_ARG]]#0 : !fir.ref>> +// CHECK-DAG: %[[DIM_ADDR:.*]] = fir.box_addr %[[DIM_PTR]] +// CHECK-DAG: %[[DIM_VAR:.*]] = fir.load %[[DIM_ADDR]] +// CHECK-DAG: %[[DIM:.*]] = hlfir.no_reassoc %[[DIM_VAR]] + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap>> +// CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> +// CHECK-DAG: %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]]) +// CHECK-DAG: fir.store %[[RET_EMBOX]] to %[[RET_BOX]] +// CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 + +// CHECK: %[[NONE:.*]] = fir.call @_FortranAAllDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) : (!fir.ref>, !fir.box, i32, !fir.ref, i32) -> none +// CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] +// CHECK: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] +// CHECK-NEXT: %[[ADDR:.*]] = fir.box_addr %[[RET]] +// CHECK-NEXT: %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// CHECK-NEXT: %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"} +// CHECK: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box>>, i1) -> !hlfir.expr> +// CHECK: hlfir.assign %[[EXPR]] to %[[RES]] +// CHECK: hlfir.destroy %[[EXPR]] +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file