Index: flang/include/flang/Optimizer/Builder/FIRBuilder.h =================================================================== --- flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -63,6 +63,8 @@ setFastMathFlags(fmi.getFastMathFlagsAttr().getValue()); } } + FirOpBuilder(mlir::OpBuilder &builder, mlir::Operation *op) + : FirOpBuilder(builder, fir::getKindMapping(op), op) {} // The listener self-reference has to be updated in case of copy-construction. FirOpBuilder(const FirOpBuilder &other) Index: flang/include/flang/Optimizer/Dialect/Support/FIRContext.h =================================================================== --- flang/include/flang/Optimizer/Dialect/Support/FIRContext.h +++ flang/include/flang/Optimizer/Dialect/Support/FIRContext.h @@ -22,6 +22,7 @@ namespace mlir { class ModuleOp; +class Operation; } // namespace mlir namespace fir { @@ -43,6 +44,12 @@ /// default. KindMapping getKindMapping(mlir::ModuleOp mod); +/// Get the KindMapping in effect for \p op, which must not be null: it is +/// read from \p op itself when \p op is a ModuleOp, otherwise from the +/// ModuleOp enclosing \p op. When no ModuleOp can be reached, the default +/// KindMapping is the intended result. +KindMapping getKindMapping(mlir::Operation *op); + /// Helper for determining the target from the host, etc. Tools may use this /// function to provide a consistent interpretation of the `--target=` /// command-line option. 
Index: flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
===================================================================
--- flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
+++ flang/lib/Optimizer/Dialect/Support/FIRContext.cpp
@@ -51,6 +51,23 @@
   return fir::KindMapping(ctx);
 }
 
+/// Return the KindMapping in effect for \p op (which must not be null):
+/// the one recorded on \p op when it is a ModuleOp, otherwise the one
+/// recorded on the ModuleOp enclosing \p op.
+fir::KindMapping fir::getKindMapping(mlir::Operation *op) {
+  // \p op may itself be the module; otherwise walk up to the enclosing one.
+  auto moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);
+  if (!moduleOp)
+    moduleOp = op->getParentOfType<mlir::ModuleOp>();
+
+  // No module is reachable (e.g. \p op is not nested in one): return the
+  // default mapping here rather than forwarding a null ModuleOp.
+  if (!moduleOp)
+    return fir::KindMapping(op->getContext());
+
+  return getKindMapping(moduleOp);
+}
+
 std::string fir::determineTargetTriple(llvm::StringRef triple) {
   // Treat "" or "default" as stand-ins for the default machine.
   if (triple.empty() || triple == "default")
Index: flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -253,8 +253,7 @@ if (fir::isa_trivial(apply.getType())) { result = rewriter.create(loc, result); } else { - auto module = apply->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, apply.getOperation()); result = packageBufferizedExpr(loc, builder, hlfir::Entity{result}, false); } @@ -288,8 +287,7 @@ matchAndRewrite(hlfir::ConcatOp concat, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = concat->getLoc(); - auto module = concat->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, concat.getOperation()); assert(adaptor.getStrings().size() >= 2 && "must have at least two strings operands"); if (adaptor.getStrings().size() > 2) @@ -328,8 +326,7 @@ matchAndRewrite(hlfir::SetLengthOp setLength, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = setLength->getLoc(); - auto module = 
setLength->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, setLength.getOperation()); // Create a temp with the new length. hlfir::Entity string = getBufferizedExprStorage(adaptor.getString()); auto charType = hlfir::getFortranElementType(setLength.getType()); @@ -362,8 +359,7 @@ matchAndRewrite(hlfir::GetLengthOp getLength, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = getLength->getLoc(); - auto module = getLength->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, getLength.getOperation()); hlfir::Entity bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr()); mlir::Value length = hlfir::genCharLength(loc, builder, bufferizedExpr); if (!length) @@ -436,8 +432,7 @@ matchAndRewrite(hlfir::AssociateOp associate, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = associate->getLoc(); - auto module = associate->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, associate.getOperation()); mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource()); const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType()); @@ -577,8 +572,7 @@ matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = endAssociate->getLoc(); - auto module = endAssociate->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, endAssociate.getOperation()); genFreeIfMustFree(loc, builder, adaptor.getVar(), adaptor.getMustFree()); rewriter.eraseOp(endAssociate); return mlir::success(); @@ -597,8 +591,7 @@ mlir::Location loc = destroy->getLoc(); hlfir::Entity bufferizedExpr = 
getBufferizedExprStorage(adaptor.getExpr()); if (!fir::isa_trivial(bufferizedExpr.getType())) { - auto module = destroy->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, destroy.getOperation()); mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getExpr()); mlir::Value firBase = bufferizedExpr.getFirBase(); genFreeIfMustFree(loc, builder, firBase, mustFree); @@ -617,8 +610,7 @@ matchAndRewrite(hlfir::NoReassocOp noreassoc, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = noreassoc->getLoc(); - auto module = noreassoc->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, noreassoc.getOperation()); mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getVal()); mlir::Value result = builder.create(loc, bufferizedExpr); @@ -677,8 +669,7 @@ matchAndRewrite(hlfir::ElementalOp elemental, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = elemental->getLoc(); - auto module = elemental->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, elemental.getOperation()); // The body of the elemental op may contain operation that will require // to be translated. Notify the rewriter about the cloned operations. HLFIRListener listener{builder, rewriter}; @@ -743,11 +734,10 @@ matchAndRewrite(hlfir::CharExtremumOp char_extremum, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = char_extremum->getLoc(); - auto module = char_extremum->getParentOfType(); auto predicate = char_extremum.getPredicate(); bool predIsMin = predicate == hlfir::CharExtremumPredicate::min ? 
true : false; - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, char_extremum.getOperation()); assert(adaptor.getStrings().size() >= 2 && "must have at least two strings operands"); auto numOperands = adaptor.getStrings().size(); Index: flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -240,8 +240,7 @@ matchAndRewrite(hlfir::CopyInOp copyInOp, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = copyInOp.getLoc(); - auto module = copyInOp->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, copyInOp.getOperation()); CopyInResult result = copyInOp.getVarIsPresent() ? genOptionalCopyIn(loc, builder, copyInOp) : genNonOptionalCopyIn(loc, builder, copyInOp); @@ -259,8 +258,7 @@ matchAndRewrite(hlfir::CopyOutOp copyOutOp, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = copyOutOp.getLoc(); - auto module = copyOutOp->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, copyOutOp.getOperation()); builder.genIfThen(loc, copyOutOp.getWasCopied()) .genThen([&]() { @@ -323,8 +321,7 @@ mlir::Value hlfirBase; mlir::Type hlfirBaseType = declareOp.getBase().getType(); if (hlfirBaseType.isa()) { - auto module = declareOp->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, declareOp.getOperation()); // Helper to generate the hlfir fir.box with the local lower bounds and // type parameters. 
auto genHlfirBox = [&]() -> mlir::Value { @@ -423,8 +420,7 @@ matchAndRewrite(hlfir::DesignateOp designate, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = designate.getLoc(); - auto module = designate->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, designate.getOperation()); hlfir::Entity baseEntity(designate.getMemref()); Index: flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp +++ flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp @@ -14,7 +14,6 @@ #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" -#include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -91,8 +90,7 @@ assert(elemental.getRegion().hasOneBlock() && "expect elemental region to have one block"); - fir::FirOpBuilder builder{rewriter, - fir::KindMapping{rewriter.getContext()}}; + fir::FirOpBuilder builder{rewriter, elemental.getOperation()}; builder.setInsertionPointAfter(apply); hlfir::YieldElementOp yield = hlfir::inlineElementalOp( elemental.getLoc(), builder, elemental, apply.getIndices()); Index: flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -69,8 +69,7 @@ mlir::PatternRewriter &rewriter, const fir::IntrinsicArgumentLoweringRules *argLowering) const { mlir::Location loc = op->getLoc(); - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, op}; + 
fir::FirOpBuilder builder{rewriter, op}; llvm::SmallVector ret; llvm::SmallVector, 2> cleanupFns; @@ -229,8 +228,7 @@ return mlir::failure(); } - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, operation}; + fir::FirOpBuilder builder{rewriter, operation.getOperation()}; const mlir::Location &loc = operation->getLoc(); mlir::Type i32 = builder.getI32Type(); @@ -271,8 +269,7 @@ mlir::LogicalResult matchAndRewrite(hlfir::CountOp count, mlir::PatternRewriter &rewriter) const override { - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, count}; + fir::FirOpBuilder builder{rewriter, count.getOperation()}; const mlir::Location &loc = count->getLoc(); mlir::Type i32 = builder.getI32Type(); @@ -304,8 +301,7 @@ mlir::LogicalResult matchAndRewrite(hlfir::MatmulOp matmul, mlir::PatternRewriter &rewriter) const override { - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, matmul}; + fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; const mlir::Location &loc = matmul->getLoc(); mlir::Value lhs = matmul.getLhs(); @@ -336,8 +332,7 @@ mlir::LogicalResult matchAndRewrite(hlfir::DotProductOp dotProduct, mlir::PatternRewriter &rewriter) const override { - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, dotProduct}; + fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()}; const mlir::Location &loc = dotProduct->getLoc(); mlir::Value lhs = dotProduct.getLhs(); @@ -368,8 +363,7 @@ mlir::LogicalResult matchAndRewrite(hlfir::TransposeOp transpose, mlir::PatternRewriter &rewriter) const override { - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, transpose}; + fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; const mlir::Location &loc = transpose->getLoc(); mlir::Value arg = 
transpose.getArray(); @@ -399,8 +393,7 @@ mlir::LogicalResult matchAndRewrite(hlfir::MatmulTransposeOp multranspose, mlir::PatternRewriter &rewriter) const override { - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping, multranspose}; + fir::FirOpBuilder builder{rewriter, multranspose.getOperation()}; const mlir::Location &loc = multranspose->getLoc(); mlir::Value lhs = multranspose.getLhs(); Index: flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp =================================================================== --- flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -13,7 +13,6 @@ #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Dialect/FIRDialect.h" -#include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" @@ -40,8 +39,7 @@ matchAndRewrite(hlfir::TransposeOp transpose, mlir::PatternRewriter &rewriter) const override { mlir::Location loc = transpose.getLoc(); - fir::KindMapping kindMapping{rewriter.getContext()}; - fir::FirOpBuilder builder{rewriter, kindMapping}; + fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; hlfir::ExprType expr = transpose.getType(); mlir::Type elementType = expr.getElementType(); hlfir::Entity array = hlfir::Entity{transpose.getArray()}; Index: flang/test/HLFIR/count-lowering-default-int-kinds.fir =================================================================== --- /dev/null +++ flang/test/HLFIR/count-lowering-default-int-kinds.fir @@ -0,0 +1,47 @@ +// Test hlfir.count operation lowering with different default integer kinds. 
+// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s + +module attributes {fir.defaultkind = "a1c4d8i8l4r4", fir.kindmap = ""} { + func.func @test_i8(%arg0: !fir.box>> {fir.bindc_name = "x"}, %arg1: i64) { + %4 = hlfir.count %arg0 dim %arg1 : (!fir.box>>, i64) -> !hlfir.expr + return + } +} +// CHECK-LABEL: func.func @test_i8 +// CHECK: %[[KIND:.*]] = arith.constant 8 : index +// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32 +// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none + +module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = ""} { + func.func @test_i4(%arg0: !fir.box>> {fir.bindc_name = "x"}, %arg1: i64) { + %4 = hlfir.count %arg0 dim %arg1 : (!fir.box>>, i64) -> !hlfir.expr + return + } +} +// CHECK-LABEL: func.func @test_i4 +// CHECK: %[[KIND:.*]] = arith.constant 4 : index +// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32 +// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none + +module attributes {fir.defaultkind = "a1c4d8i2l4r4", fir.kindmap = ""} { + func.func @test_i2(%arg0: !fir.box>> {fir.bindc_name = "x"}, %arg1: i64) { + %4 = hlfir.count %arg0 dim %arg1 : (!fir.box>>, i64) -> !hlfir.expr + return + } +} +// CHECK-LABEL: func.func @test_i2 +// CHECK: %[[KIND:.*]] = arith.constant 2 : index +// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32 +// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none + +module attributes {fir.defaultkind = "a1c4d8i1l4r4", fir.kindmap = ""} { + func.func @test_i1(%arg0: !fir.box>> {fir.bindc_name = "x"}, %arg1: i64) { + %4 = hlfir.count %arg0 dim %arg1 : (!fir.box>>, i64) -> !hlfir.expr + return + } +} +// CHECK-LABEL: func.func @test_i1 +// CHECK: 
arith.constant 1 : index +// CHECK: %[[KIND:.*]] = arith.constant 1 : index +// CHECK: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32 +// CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND_ARG]], %{{.*}}, %{{.*}}) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none Index: flang/test/HLFIR/count-lowering.fir =================================================================== --- flang/test/HLFIR/count-lowering.fir +++ flang/test/HLFIR/count-lowering.fir @@ -39,6 +39,7 @@ // CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] // CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>> +// CHECK-DAG: %[[KIND:.*]] = arith.constant 4 : index // CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> @@ -48,8 +49,9 @@ // CHECK-DAG: %[[DIM:.*]] = fir.load %[[DIM_VAR]]#0 : !fir.ref // CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] // CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 +// CHECK-DAG: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32 -// CHECK: %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none +// CHECK: %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[KIND_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none // CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] // CHECK: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] // CHECK-NEXT: %[[ADDR:.*]] = fir.box_addr %[[RET]] @@ -80,6 +82,7 @@ // CHECK-LABEL: func.func @_QPcount3( // CHECK: %[[ARG0:.*]]: !fir.ref> // CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>> +// CHECK-DAG: %[[KIND:.*]] = arith.constant 4 : index // CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: 
%[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> @@ -95,7 +98,9 @@ // CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] // CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box>>) -> !fir.box -// CHECK: %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) +// CHECK-DAG: %[[KIND_ARG:.*]] = fir.convert %[[KIND]] : (index) -> i32 + +// CHECK: %[[NONE:.*]] = fir.call @_FortranACountDim(%[[RET_ARG]], %[[MASK_ARG]], %[[DIM]], %[[KIND_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) // CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] // CHECK: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] // CHECK-NEXT: %[[ADDR:.*]] = fir.box_addr %[[RET]]