diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.h b/flang/include/flang/Optimizer/HLFIR/Passes.h --- a/flang/include/flang/Optimizer/HLFIR/Passes.h +++ b/flang/include/flang/Optimizer/HLFIR/Passes.h @@ -26,6 +26,7 @@ std::unique_ptr createBufferizeHLFIRPass(); std::unique_ptr createLowerHLFIRIntrinsicsPass(); std::unique_ptr createSimplifyHLFIRIntrinsicsPass(); +std::unique_ptr createInlineElementalsPass(); #define GEN_PASS_REGISTRATION #include "flang/Optimizer/HLFIR/Passes.h.inc" diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -30,4 +30,9 @@ let constructor = "hlfir::createSimplifyHLFIRIntrinsicsPass()"; } +def InlineElementals : Pass<"inline-elementals", "::mlir::func::FuncOp"> { + let summary = "Inline chained hlfir.elemental operations"; + let constructor = "hlfir::createInlineElementalsPass()"; +} + #endif //FORTRAN_DIALECT_HLFIR_PASSES diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -242,6 +242,7 @@ addCanonicalizerPassWithoutRegionSimplification(pm); pm.addPass(hlfir::createSimplifyHLFIRIntrinsicsPass()); } + pm.addPass(hlfir::createInlineElementalsPass()); pm.addPass(hlfir::createLowerHLFIRIntrinsicsPass()); pm.addPass(hlfir::createBufferizeHLFIRPass()); pm.addPass(hlfir::createConvertHLFIRtoFIRPass()); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_flang_library(HLFIRTransforms BufferizeHLFIR.cpp ConvertToFIR.cpp + InlineElementals.cpp LowerHLFIRIntrinsics.cpp SimplifyHLFIRIntrinsics.cpp diff --git 
a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp @@ -0,0 +1,111 @@ +//===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Chained elemental operations like a + b + c can inline the first elemental +// at the hlfir.apply in the body of the second one (as described in +// docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering +// so that it happens after the HLFIR intrinsic simplification pass. +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Dialect/Support/FIRContext.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/HLFIR/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +namespace hlfir { +#define GEN_PASS_DEF_INLINEELEMENTALS +#include "flang/Optimizer/HLFIR/Passes.h.inc" +} // namespace hlfir + +namespace { + +/// If the elemental has exactly two uses, one hlfir.apply and one +/// hlfir.destroy, return that (apply, destroy) pair; otherwise return {}. +static std::optional> +getTwoUses(hlfir::ElementalOp elemental) { + mlir::Operation::user_range users = elemental->getUsers(); + // don't inline anything with more than one use (plus hlfir.destroy) + if (std::distance(users.begin(), 
users.end()) != 2) { + return {}; + } + + hlfir::ApplyOp apply; + hlfir::DestroyOp destroy; + for (mlir::Operation *user : users) + mlir::TypeSwitch(user) + .Case([&](hlfir::ApplyOp op) { apply = op; }) + .Case([&](hlfir::DestroyOp op) { destroy = op; }); + + if (!apply || !destroy) + return {}; + return std::pair{apply, destroy}; +} + +class InlineElementalConversion + : public mlir::OpRewritePattern { +public: + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(hlfir::ElementalOp elemental, + mlir::PatternRewriter &rewriter) const override { + mlir::Location loc = elemental.getLoc(); + mlir::ModuleOp mod = elemental->getParentOfType(); + fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod)); + + // the optional is never empty here, else the op would already be legal + auto [apply, destroy] = *getTwoUses(elemental); + + builder.setInsertionPointAfter(apply); + hlfir::YieldElementOp yield = + hlfir::inlineElementalOp(loc, builder, elemental, apply.getIndices()); + + rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue()); + rewriter.eraseOp(yield); + rewriter.eraseOp(apply); + rewriter.eraseOp(destroy); + rewriter.eraseOp(elemental); + + return mlir::success(); + } +}; + +class InlineElementalsPass + : public hlfir::impl::InlineElementalsBase { +public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + mlir::MLIRContext *context = &getContext(); + mlir::RewritePatternSet patterns(context); + patterns.insert(context); + + mlir::ConversionTarget target(*context); + target.markUnknownOpDynamicallyLegal( + [](mlir::Operation *) { return true; }); + target.addDynamicallyLegalOp( + [](hlfir::ElementalOp elemental) { return !getTwoUses(elemental); }); + + if (mlir::failed( + mlir::applyFullConversion(func, target, std::move(patterns)))) { + mlir::emitError(func->getLoc(), "failure in HLFIR elemental inlining"); + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr 
hlfir::createInlineElementalsPass() { + return std::make_unique(); +} diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90 --- a/flang/test/Driver/mlir-debug-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90 @@ -25,6 +25,8 @@ ! ALL: Pass statistics report ! ALL: Fortran::lower::VerifierPass +! ALL-NEXT: 'func.func' Pipeline +! ALL-NEXT: InlineElementals ! ALL-NEXT: LowerHLFIRIntrinsics ! ALL-NEXT: BufferizeHLFIR ! ALL-NEXT: ConvertHLFIRtoFIR diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -15,6 +15,7 @@ ! O2-NEXT: Canonicalizer ! O2-NEXT: 'func.func' Pipeline ! O2-NEXT: SimplifyHLFIRIntrinsics +! ALL: InlineElementals ! ALL-NEXT: LowerHLFIRIntrinsics ! ALL-NEXT: BufferizeHLFIR ! ALL-NEXT: ConvertHLFIRtoFIR diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -19,6 +19,7 @@ // PASSES: Canonicalizer // PASSES-NEXT: 'func.func' Pipeline // PASSES-NEXT: SimplifyHLFIRIntrinsics +// PASSES-NEXT: InlineElementals // PASSES-NEXT: LowerHLFIRIntrinsics // PASSES-NEXT: BufferizeHLFIR // PASSES-NEXT: ConvertHLFIRtoFIR diff --git a/flang/test/HLFIR/inline-elemental.fir b/flang/test/HLFIR/inline-elemental.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/inline-elemental.fir @@ -0,0 +1,176 @@ +// RUN: fir-opt --inline-elementals %s | FileCheck %s + +// check inlining one elemental into another +// a = b * c + d +func.func @inline_to_elemental(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.box> {fir.bindc_name = "b"}, %arg2: !fir.box> {fir.bindc_name = "c"}, %arg3: !fir.box> {fir.bindc_name = "d"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg1 {uniq_name = "b"} : 
(!fir.box>) -> (!fir.box>, !fir.box>) + %2:2 = hlfir.declare %arg2 {uniq_name = "c"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %3:2 = hlfir.declare %arg3 {uniq_name = "d"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %c0 = arith.constant 0 : index + %4:3 = fir.box_dims %1#0, %c0 : (!fir.box>, index) -> (index, index, index) + %5 = fir.shape %4#1 : (index) -> !fir.shape<1> + %6 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg4: index): + %8 = hlfir.designate %1#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %9 = hlfir.designate %2#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %10 = fir.load %8 : !fir.ref + %11 = fir.load %9 : !fir.ref + %12 = arith.muli %10, %11 : i32 + hlfir.yield_element %12 : i32 + } + %7 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg4: index): + %8 = hlfir.apply %6, %arg4 : (!hlfir.expr, index) -> i32 + %9 = hlfir.designate %3#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %10 = fir.load %9 : !fir.ref + %11 = arith.addi %8, %10 : i32 + hlfir.yield_element %11 : i32 + } + hlfir.assign %7 to %0#0 : !hlfir.expr, !fir.box> + hlfir.destroy %7 : !hlfir.expr + hlfir.destroy %6 : !hlfir.expr + return +} +// CHECK-LABEL: func.func @inline_to_elemental +// CHECK-SAME: %[[A_ARG:.*]]: !fir.box> {fir.bindc_name = "a"} +// CHECK-SAME: %[[B_ARG:.*]]: !fir.box> {fir.bindc_name = "b"} +// CHECK-SAME: %[[C_ARG:.*]]: !fir.box> {fir.bindc_name = "c"} +// CHECK-SAME: %[[D_ARG:.*]]: !fir.box> {fir.bindc_name = "d"} +// CHECK-DAG: %[[A:.*]]:2 = hlfir.declare %[[A_ARG]] +// CHECK-DAG: %[[B:.*]]:2 = hlfir.declare %[[B_ARG]] +// CHECK-DAG: %[[C:.*]]:2 = hlfir.declare %[[C_ARG]] +// CHECK-DAG: %[[D:.*]]:2 = hlfir.declare %[[D_ARG]] +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[B_DIM0:.*]]:3 = fir.box_dims %[[B]]#0, %[[C0]] +// CHECK-NEXT: %[[B_SHAPE:.*]] = fir.shape %[[B_DIM0]]#1 +// CHECK-NEXT: %[[EXPR:.*]] = hlfir.elemental %[[B_SHAPE]] +// CHECK-NEXT: ^bb0(%[[I:.*]]: index): +// inline the first elemental: +// 
CHECK-NEXT: %[[B_I_REF:.*]] = hlfir.designate %[[B]]#0 (%[[I]]) +// CHECK-NEXT: %[[C_I_REF:.*]] = hlfir.designate %[[C]]#0 (%[[I]]) +// CHECK-NEXT: %[[B_I:.*]] = fir.load %[[B_I_REF]] +// CHECK-NEXT: %[[C_I:.*]] = fir.load %[[C_I_REF]] +// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[B_I]], %[[C_I]] +// second elemental: +// CHECK-NEXT: %[[D_I_REF:.*]] = hlfir.designate %[[D]]#0 (%[[I]]) +// CHECK-NEXT: %[[D_I:.*]] = fir.load %[[D_I_REF]] +// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[MUL]], %[[D_I]] +// CHECK-NEXT: hlfir.yield_element %[[ADD]] +// CHECK-NEXT: } +// CHECK-NEXT: hlfir.assign %[[EXPR]] to %[[A]]#0 +// CHECK-NEXT: hlfir.destroy %[[EXPR]] +// CHECK-NEXT: return +// CHECK-NEXT: } + +// check inlining into a do_loop +func.func @inline_to_loop(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.box> {fir.bindc_name = "b"}, %arg2: !fir.box> {fir.bindc_name = "c"}, %arg3: !fir.box> {fir.bindc_name = "d"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %2:2 = hlfir.declare %arg2 {uniq_name = "c"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %3:2 = hlfir.declare %arg3 {uniq_name = "d"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %c0 = arith.constant 0 : index + %4:3 = fir.box_dims %1#0, %c0 : (!fir.box>, index) -> (index, index, index) + %5 = fir.shape %4#1 : (index) -> !fir.shape<1> + %6 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg4: index): + %8 = hlfir.designate %1#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %9 = hlfir.designate %2#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %10 = fir.load %8 : !fir.ref + %11 = fir.load %9 : !fir.ref + %12 = arith.muli %10, %11 : i32 + hlfir.yield_element %12 : i32 + } + %array = fir.array_load %0#0 : (!fir.box>) -> !fir.array + %c1 = arith.constant 1 : index + %max = arith.subi %4#1, %c1 : index + %7 = fir.do_loop %arg4 = %c0 to %max step %c1 unordered iter_args(%arg5 = %array) -> 
(!fir.array) { + %8 = hlfir.apply %6, %arg4 : (!hlfir.expr, index) -> i32 + %9 = hlfir.designate %3#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %10 = fir.load %9 : !fir.ref + %11 = arith.addi %8, %10 : i32 + %12 = fir.array_update %arg5, %11, %arg4 : (!fir.array, i32, index) -> !fir.array + fir.result %12 : !fir.array + } + fir.array_merge_store %array, %7 to %arg0 : !fir.array, !fir.array, !fir.box> + hlfir.destroy %6 : !hlfir.expr + return +} +// CHECK-LABEL: func.func @inline_to_loop +// CHECK-SAME: %[[A_ARG:.*]]: !fir.box> {fir.bindc_name = "a"} +// CHECK-SAME: %[[B_ARG:.*]]: !fir.box> {fir.bindc_name = "b"} +// CHECK-SAME: %[[C_ARG:.*]]: !fir.box> {fir.bindc_name = "c"} +// CHECK-SAME: %[[D_ARG:.*]]: !fir.box> {fir.bindc_name = "d"} +// CHECK-DAG: %[[A:.*]]:2 = hlfir.declare %[[A_ARG]] +// CHECK-DAG: %[[B:.*]]:2 = hlfir.declare %[[B_ARG]] +// CHECK-DAG: %[[C:.*]]:2 = hlfir.declare %[[C_ARG]] +// CHECK-DAG: %[[D:.*]]:2 = hlfir.declare %[[D_ARG]] +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[B_DIM0:.*]]:3 = fir.box_dims %[[B]]#0, %[[C0]] +// CHECK-NEXT: %[[B_SHAPE:.*]] = fir.shape %[[B_DIM0]]#1 +// CHECK-NEXT: %[[ARRAY:.*]] = fir.array_load %[[A]]#0 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 +// CHECK-NEXT: %[[MAX:.*]] = arith.subi %[[B_DIM0]]#1, %[[C1]] +// CHECK-NEXT: %[[LOOP:.*]] = fir.do_loop %[[I:.*]] = %[[C0]] to %[[MAX]] step %[[C1]] unordered iter_args(%[[LOOP_ARRAY:.*]] = %[[ARRAY]]) +// inline the elemental: +// CHECK-NEXT: %[[B_I_REF:.*]] = hlfir.designate %[[B]]#0 (%[[I]]) +// CHECK-NEXT: %[[C_I_REF:.*]] = hlfir.designate %[[C]]#0 (%[[I]]) +// CHECK-NEXT: %[[B_I:.*]] = fir.load %[[B_I_REF]] +// CHECK-NEXT: %[[C_I:.*]] = fir.load %[[C_I_REF]] +// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[B_I]], %[[C_I]] +// loop body: +// CHECK-NEXT: %[[D_I_REF:.*]] = hlfir.designate %[[D]]#0 (%[[I]]) +// CHECK-NEXT: %[[D_I:.*]] = fir.load %[[D_I_REF]] +// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[MUL]], %[[D_I]] +// CHECK-NEXT: 
%[[ARRAY_UPD:.*]] = fir.array_update %[[LOOP_ARRAY]], %[[ADD]], %[[I]] +// CHECK-NEXT: fir.result %[[ARRAY_UPD]] +// CHECK-NEXT: } +// CHECK-NEXT: fir.array_merge_store %[[ARRAY]], %[[LOOP]] to %[[A_ARG]] +// CHECK-NEXT: return +// CHECK-NEXT: } + +// inlining into a single hlfir.apply +// a = (b * c)[1] +func.func @inline_to_apply(%arg0: !fir.ref {fir.bindc_name = "a"}, %arg1: !fir.box> {fir.bindc_name = "b"}, %arg2: !fir.box> {fir.bindc_name = "c"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %2:2 = hlfir.declare %arg2 {uniq_name = "c"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %c0 = arith.constant 0 : index + %4:3 = fir.box_dims %1#0, %c0 : (!fir.box>, index) -> (index, index, index) + %5 = fir.shape %4#1 : (index) -> !fir.shape<1> + %6 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg4: index): + %8 = hlfir.designate %1#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %9 = hlfir.designate %2#0 (%arg4) : (!fir.box>, index) -> !fir.ref + %10 = fir.load %8 : !fir.ref + %11 = fir.load %9 : !fir.ref + %12 = arith.muli %10, %11 : i32 + hlfir.yield_element %12 : i32 + } + %c1 = arith.constant 1 : index + %res = hlfir.apply %6, %c1 : (!hlfir.expr, index) -> i32 + fir.store %res to %0#0 : !fir.ref + hlfir.destroy %6 : !hlfir.expr + return +} +// CHECK-LABEL: func.func @inline_to_apply +// CHECK-SAME: %[[A_ARG:.*]]: !fir.ref {fir.bindc_name = "a"} +// CHECK-SAME: %[[B_ARG:.*]]: !fir.box> {fir.bindc_name = "b"} +// CHECK-SAME: %[[C_ARG:.*]]: !fir.box> {fir.bindc_name = "c"} +// CHECK-DAG: %[[A:.*]]:2 = hlfir.declare %[[A_ARG]] +// CHECK-DAG: %[[B:.*]]:2 = hlfir.declare %[[B_ARG]] +// CHECK-DAG: %[[C:.*]]:2 = hlfir.declare %[[C_ARG]] +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[B_DIM0:.*]]:3 = fir.box_dims %[[B]]#0, %[[C0]] +// CHECK-NEXT: %[[B_SHAPE:.*]] = fir.shape %[[B_DIM0]]#1 +// CHECK-NEXT: 
%[[C1:.*]] = arith.constant 1 : index +// inline the elemental: +// CHECK-NEXT: %[[B_1_REF:.*]] = hlfir.designate %[[B]]#0 (%[[C1]]) +// CHECK-NEXT: %[[C_1_REF:.*]] = hlfir.designate %[[C]]#0 (%[[C1]]) +// CHECK-NEXT: %[[B_1:.*]] = fir.load %[[B_1_REF]] +// CHECK-NEXT: %[[C_1:.*]] = fir.load %[[C_1_REF]] +// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[B_1]], %[[C_1]] +// store: +// CHECK-NEXT: fir.store %[[MUL]] to %[[A]]#0 : !fir.ref +// CHECK-NEXT: return +// CHECK-NEXT: }