diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -315,6 +315,11 @@ fir::FirOpBuilder &builder, mlir::Value shape); +/// Return explicit extents. If the base is a fir.box, this won't read it to +/// return the extents and will instead return an empty vector. +llvm::SmallVector +getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder); + /// Read length parameters into result if this entity has any. void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.h b/flang/include/flang/Optimizer/HLFIR/Passes.h --- a/flang/include/flang/Optimizer/HLFIR/Passes.h +++ b/flang/include/flang/Optimizer/HLFIR/Passes.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_OPTIMIZER_HLFIR_PASSES_H #define FORTRAN_OPTIMIZER_HLFIR_PASSES_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include @@ -24,6 +25,7 @@ std::unique_ptr createConvertHLFIRtoFIRPass(); std::unique_ptr createBufferizeHLFIRPass(); std::unique_ptr createLowerHLFIRIntrinsicsPass(); +std::unique_ptr createSimplifyHLFIRIntrinsicsPass(); #define GEN_PASS_REGISTRATION #include "flang/Optimizer/HLFIR/Passes.h.inc" diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -25,4 +25,9 @@ let constructor = "hlfir::createLowerHLFIRIntrinsicsPass()"; } +def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics", "::mlir::func::FuncOp"> { + let summary = "Simplify HLFIR intrinsic operations that don't need to result in runtime calls"; + let constructor = "hlfir::createSimplifyHLFIRIntrinsicsPass()"; +} + #endif //FORTRAN_DIALECT_HLFIR_PASSES diff 
--git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -238,8 +238,10 @@ /// passes pipeline inline void createHLFIRToFIRPassPipeline( mlir::PassManager &pm, llvm::OptimizationLevel optLevel = defaultOptLevel) { - if (optLevel.isOptimizingForSpeed()) + if (optLevel.isOptimizingForSpeed()) { addCanonicalizerPassWithoutRegionSimplification(pm); + pm.addPass(hlfir::createSimplifyHLFIRIntrinsicsPass()); + } pm.addPass(hlfir::createLowerHLFIRIntrinsicsPass()); pm.addPass(hlfir::createBufferizeHLFIRPass()); pm.addPass(hlfir::createConvertHLFIRtoFIRPass()); diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -23,8 +23,9 @@ // Return explicit extents. If the base is a fir.box, this won't read it to // return the extents and will instead return an empty vector. 
-static llvm::SmallVector -getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder) { +llvm::SmallVector +hlfir::getExplicitExtentsFromShape(mlir::Value shape, + fir::FirOpBuilder &builder) { llvm::SmallVector result; auto *shapeOp = shape.getDefiningOp(); if (auto s = mlir::dyn_cast_or_null(shapeOp)) { @@ -62,7 +63,7 @@ getExplicitExtents(fir::FortranVariableOpInterface var, fir::FirOpBuilder &builder) { if (mlir::Value shape = var.getShape()) - return getExplicitExtentsFromShape(var.getShape(), builder); + return hlfir::getExplicitExtentsFromShape(var.getShape(), builder); return {}; } @@ -404,7 +405,7 @@ assert((shape.getType().isa() || shape.getType().isa()) && "shape must contain extents"); - auto extents = getExplicitExtentsFromShape(shape, builder); + auto extents = hlfir::getExplicitExtentsFromShape(shape, builder); auto lowers = getExplicitLboundsFromShape(shape); assert(lowers.empty() || lowers.size() == extents.size()); mlir::Type idxTy = builder.getIndexType(); @@ -527,7 +528,7 @@ hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) { llvm::SmallVector extents = - getExplicitExtentsFromShape(shape, builder); + hlfir::getExplicitExtentsFromShape(shape, builder); mlir::Type indexType = builder.getIndexType(); for (auto &extent : extents) extent = builder.createConvert(loc, indexType, extent); @@ -538,7 +539,7 @@ hlfir::Entity entity, unsigned dim) { entity = followShapeInducingSource(entity); if (auto shape = tryRetrievingShapeOrShift(entity)) { - auto extents = getExplicitExtentsFromShape(shape, builder); + auto extents = hlfir::getExplicitExtentsFromShape(shape, builder); if (!extents.empty()) { assert(extents.size() > dim && "bad inquiry"); return extents[dim]; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ 
-4,6 +4,7 @@ BufferizeHLFIR.cpp ConvertToFIR.cpp LowerHLFIRIntrinsics.cpp + SimplifyHLFIRIntrinsics.cpp DEPENDS FIRDialect diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -0,0 +1,125 @@ +//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Normally transformational intrinsics are lowered to calls to runtime +// functions. However, some cases of the intrinsics are faster when inlined +// into the calling function. +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/HLFIR/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/Location.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace hlfir { +#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS +#include "flang/Optimizer/HLFIR/Passes.h.inc" +} // namespace hlfir + +namespace { + +/// Rewrite hlfir.transpose as an hlfir.elemental whose kernel reads the operand +/// element at swapped indices. +class TransposeAsElementalConversion + : public mlir::OpRewritePattern<hlfir::TransposeOp> { +public: + using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(hlfir::TransposeOp transpose, + mlir::PatternRewriter &rewriter) 
const override { + mlir::Location loc = transpose.getLoc(); + fir::KindMapping kindMapping{rewriter.getContext()}; + fir::FirOpBuilder builder{rewriter, kindMapping}; + hlfir::ExprType expr = transpose.getType(); + mlir::Type elementType = expr.getElementType(); + hlfir::Entity array = hlfir::Entity{transpose.getArray()}; + mlir::Value resultShape = genResultShape(loc, builder, array); + llvm::SmallVector<mlir::Value, 1> typeParams; + hlfir::genLengthParameters(loc, builder, array, typeParams); + + auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange inputIndices) -> hlfir::Entity { + assert(inputIndices.size() == 2 && "checked in TransposeOp::validate"); + // Keep the swapped indices in owning storage: a ValueRange built from a + // braced list dangles once the temporary initializer_list is destroyed. + llvm::SmallVector<mlir::Value, 2> transposedIndices{ + inputIndices[1], inputIndices[0]}; + hlfir::Entity element = + hlfir::getElementAt(loc, builder, array, transposedIndices); + hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element); + return val; + }; + hlfir::ElementalOp elementalOp = hlfir::genElementalOp( + loc, builder, elementType, resultShape, typeParams, genKernel); + + rewriter.replaceOp(transpose, elementalOp.getResult()); + return mlir::success(); + } + +private: + /// Build the fir.shape of the result: the extents of \p array, swapped. + static mlir::Value genResultShape(mlir::Location loc, + fir::FirOpBuilder &builder, + hlfir::Entity array) { + mlir::Value inShape = hlfir::genShape(loc, builder, array); + llvm::SmallVector<mlir::Value> inExtents = + hlfir::getExplicitExtentsFromShape(inShape, builder); + // Only the extents are kept; drop the input shape if it is an unused + // operation (a value without a defining op cannot be erased). + if (mlir::Operation *shapeOp = inShape.getDefiningOp()) + if (inShape.use_empty()) + shapeOp->erase(); + + // transpose indices + assert(inExtents.size() == 2 && "checked in TransposeOp::validate"); + return builder.create<fir::ShapeOp>( + loc, mlir::ValueRange{inExtents[1], inExtents[0]}); + } +}; + +/// Function pass that applies TransposeAsElementalConversion to every +/// hlfir.transpose whose result type is not polymorphic. +class SimplifyHLFIRIntrinsics + : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> { +public: + void runOnOperation() override { + mlir::func::FuncOp func = this->getOperation(); + mlir::MLIRContext *context = &getContext(); + mlir::RewritePatternSet patterns(context); + patterns.insert<TransposeAsElementalConversion>(context); + 
mlir::ConversionTarget target(*context); + // don't transform transpose of polymorphic arrays (not currently supported + // by hlfir.elemental) + target.addDynamicallyLegalOp<hlfir::TransposeOp>( + [](hlfir::TransposeOp transpose) { + return transpose.getType().cast<hlfir::ExprType>().isPolymorphic(); + }); + target.markUnknownOpDynamicallyLegal( + [](mlir::Operation *) { return true; }); + if (mlir::failed( + mlir::applyFullConversion(func, target, std::move(patterns)))) { + mlir::emitError(func->getLoc(), + "failure in HLFIR intrinsic simplification"); + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr<mlir::Pass> hlfir::createSimplifyHLFIRIntrinsicsPass() { + return std::make_unique<SimplifyHLFIRIntrinsics>(); +} diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -13,6 +13,8 @@ ! ALL: Fortran::lower::VerifierPass ! O2-NEXT: Canonicalizer +! O2-NEXT: 'func.func' Pipeline +! O2-NEXT: SimplifyHLFIRIntrinsics ! ALL-NEXT: LowerHLFIRIntrinsics ! ALL-NEXT: BufferizeHLFIR ! 
ALL-NEXT: ConvertHLFIRtoFIR diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -17,6 +17,8 @@ // PASSES: Pass statistics report // PASSES: Canonicalizer +// PASSES-NEXT: 'func.func' Pipeline +// PASSES-NEXT: SimplifyHLFIRIntrinsics // PASSES-NEXT: LowerHLFIRIntrinsics // PASSES-NEXT: BufferizeHLFIR // PASSES-NEXT: ConvertHLFIRtoFIR diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir @@ -0,0 +1,95 @@ +// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s + +// box with known extents +func.func @transpose0(%arg0: !fir.box>) { + %res = hlfir.transpose %arg0 : (!fir.box>) -> !hlfir.expr<2x1xi32> + return +} +// CHECK-LABEL: func.func @transpose0( +// CHECK-SAME: %[[ARG0:.*]]: !fir.box>) { +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2> +// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> { +// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[C1_1]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK: %[[LOWER_BOUND0:.*]] = arith.subi %[[DIMS0]]#0, %[[C1_2]] : index +// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[LOWER_BOUND0]] : index +// CHECK: %[[LOWER_BOUND1:.*]] = arith.subi %[[DIMS1]]#0, %[[C1_2]] : index +// CHECK: %[[I_OFFSET:.*]] = arith.addi %[[I]], %[[LOWER_BOUND1]] : index +// CHECK: %[[ELEMENT_REF:.*]] = 
hlfir.designate %[[ARG0]] (%[[J_OFFSET]], %[[I_OFFSET]]) : (!fir.box>, index, index) -> !fir.ref +// CHECK: %[[ELEMENT:.*]] = fir.load %[[ELEMENT_REF]] : !fir.ref +// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } + +// expr with known extents +func.func @transpose1(%arg0: !hlfir.expr<1x2xi32>) { + %res = hlfir.transpose %arg0 : (!hlfir.expr<1x2xi32>) -> !hlfir.expr<2x1xi32> + return +} +// CHECK-LABEL: func.func @transpose1( +// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<1x2xi32>) { +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2> +// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> { +// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): +// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<1x2xi32>, index, index) -> i32 +// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } + +// box with unknown extent +func.func @transpose2(%arg0: !fir.box>) { + %res = hlfir.transpose %arg0 : (!fir.box>) -> !hlfir.expr<2x?xi32> + return +} +// CHECK-LABEL: func.func @transpose2( +// CHECK-SAME: %[[ARG0:.*]]: !fir.box>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[DIMS0]]#1 : (index, index) -> !fir.shape<2> +// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> { +// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0_1]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS1_1:.*]]:3 = 
fir.box_dims %[[ARG0]], %[[C1_1]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK: %[[LOWER_BOUND0:.*]] = arith.subi %[[DIMS0]]#0, %[[C1_2]] : index +// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[LOWER_BOUND0]] : index +// CHECK: %[[LOWER_BOUND1:.*]] = arith.subi %[[DIMS1_1]]#0, %[[C1_2]] : index +// CHECK: %[[I_OFFSET:.*]] = arith.addi %[[I]], %[[LOWER_BOUND1]] : index +// CHECK: %[[ELE_REF:.*]] = hlfir.designate %[[ARG0]] (%[[J_OFFSET]], %[[I_OFFSET]]) : (!fir.box>, index, index) -> !fir.ref +// CHECK: %[[ELEMENT:.*]] = fir.load %[[ELE_REF]] : !fir.ref +// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } + +// expr with unknown extent +func.func @transpose3(%arg0: !hlfir.expr) { + %res = hlfir.transpose %arg0 : (!hlfir.expr) -> !hlfir.expr<2x?xi32> + return +} +// CHECK-LABEL: func.func @transpose3( +// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr) { +// CHECK: %[[IN_SHAPE:.*]] = hlfir.shape_of %[[ARG0]] : (!hlfir.expr) -> !fir.shape<2> +// CHECK: %[[EXTENT0:.*]] = hlfir.get_extent %[[IN_SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[OUT_SHAPE:.*]] = fir.shape %[[C2]], %[[EXTENT0]] : (index, index) -> !fir.shape<2> +// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[OUT_SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> { +// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index): +// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr, index, index) -> i32 +// CHECK: hlfir.yield_element %[[ELEMENT]] : i32 +// CHECK: } +// CHECK: return +// CHECK: }