diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -67,42 +67,37 @@ } // namespace -/// Generate function type for the simplified version of FortranASum -/// operating on the given \p elementType. -static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder, - const mlir::Type &elementType) { +/// Generate function type for the simplified version of FortranASum and +/// similar functions with a single fir.box argument returning \p elementType. +static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder, + const mlir::Type &elementType) { mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); return mlir::FunctionType::get(builder.getContext(), {boxType}, {elementType}); } -/// Generate function body of the simplified version of FortranASum -/// with signature provided by \p funcOp. The caller is responsible -/// for saving/restoring the original insertion point of \p builder. -/// \p funcOp is expected to be empty on entry to this function. -static void genFortranASumBody(fir::FirOpBuilder &builder, - mlir::func::FuncOp &funcOp) { - // function FortranASum_simplified(arr) - // T, dimension(:) :: arr - // T sum = 0 - // integer iter - // do iter = 0, extent(arr) - // sum = sum + arr[iter] - // end do - // FortranASum_simplified = sum - // end function FortranASum_simplified +using BodyOpGeneratorTy = + std::function; +using InitValGeneratorTy = std::function; + +/// Generate the reduction loop into \p funcOp. +/// +/// \p initVal is a function, called to get the initial value for +/// the reduction value. +/// \p genBody is called to fill in the actual reduction operation, +/// for example add for SUM, MAX for MAXVAL, etc.
+static void genReductionLoop(fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp, + InitValGeneratorTy initVal, + BodyOpGeneratorTy genBody) { auto loc = mlir::UnknownLoc::get(builder.getContext()); mlir::Type elementType = funcOp.getResultTypes()[0]; builder.setInsertionPointToEnd(funcOp.addEntryBlock()); mlir::IndexType idxTy = builder.getIndexType(); - mlir::Value zero = elementType.isa() - ? builder.createRealConstant(loc, elementType, 0.0) - : builder.createIntegerConstant(loc, elementType, 0); - mlir::Value sum = builder.create(loc, elementType); - builder.create(loc, zero, sum); - mlir::Block::BlockArgListType args = funcOp.front().getArguments(); mlir::Value arg = args[0]; @@ -120,7 +115,11 @@ // We use C indexing here, so len-1 as loopcount mlir::Value loopCount = builder.create(loc, len, one); - auto loop = builder.create(loc, zeroIdx, loopCount, step); + mlir::Value init = initVal(builder, loc, elementType); + auto loop = builder.create(loc, zeroIdx, loopCount, step, + /*unordered=*/false, + /*finalCountValue=*/false, init); + mlir::Value reductionVal = loop.getRegionIterArgs()[0]; // Begin loop code mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); @@ -131,24 +130,84 @@ mlir::Value addr = builder.create(loc, eleRefTy, array, index); mlir::Value elem = builder.create(loc, addr); - mlir::Value sumVal = builder.create(loc, sum); - mlir::Value res; - if (elementType.isa()) - res = builder.create(loc, elem, sumVal); - else if (elementType.isa()) - res = builder.create(loc, elem, sumVal); - else - TODO(loc, "Unsupported type"); + reductionVal = genBody(builder, loc, elementType, elem, reductionVal); - builder.create(loc, res, sum); + builder.create(loc, reductionVal); // End of loop.
builder.restoreInsertionPoint(loopEndPt); - mlir::Value resultVal = builder.create(loc, sum); + mlir::Value resultVal = loop.getResult(0); builder.create(loc, resultVal); } +/// Generate function body of the simplified version of FortranASum +/// with signature provided by \p funcOp. The caller is responsible +/// for saving/restoring the original insertion point of \p builder. +/// \p funcOp is expected to be empty on entry to this function. +static void genFortranASumBody(fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp) { + // function FortranASum_simplified(arr) + // T, dimension(:) :: arr + // T sum = 0 + // integer iter + // do iter = 0, extent(arr) + // sum = sum + arr[iter] + // end do + // FortranASum_simplified = sum + // end function FortranASum_simplified + auto zero = [](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type elementType) { + return elementType.isa() + ? builder.createRealZeroConstant(loc, elementType) + : builder.createIntegerConstant(loc, elementType, 0); + }; + + auto genBodyOp = [](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type elementType, mlir::Value elem1, + mlir::Value elem2) -> mlir::Value { + if (elementType.isa()) + return builder.create(loc, elem1, elem2); + if (elementType.isa()) + return builder.create(loc, elem1, elem2); + + llvm_unreachable("unsupported type"); + }; + + genReductionLoop(builder, funcOp, zero, genBodyOp); +} + +/// Generate function body of the simplified version of FortranAMaxval +/// with signature provided by \p funcOp. The reduction starts from the +/// lowest finite value of the element type (-HUGE for reals, the signed +/// minimum for integers), which is also the result for an empty array. +static void genFortranAMaxvalBody(fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp) { + auto init = [](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type elementType) { + if (auto ty = elementType.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true)); + } + unsigned bits = elementType.getIntOrFloatBitWidth(); + int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); + return
builder.createIntegerConstant(loc, elementType, minInt); + }; + + auto genBodyOp = [](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type elementType, mlir::Value elem1, + mlir::Value elem2) -> mlir::Value { + if (elementType.isa()) + return builder.create(loc, elem1, elem2); + if (elementType.isa()) + return builder.create(loc, elem1, elem2); + + llvm_unreachable("unsupported type"); + }; + genReductionLoop(builder, funcOp, init, genBodyOp); +} + /// Generate function type for the simplified version of FortranADotProduct /// operating on the given \p elementType. static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder, @@ -409,7 +468,7 @@ return; } auto typeGenerator = [&type](fir::FirOpBuilder &builder) { - return genFortranASumType(builder, type); + return genNoneBoxType(builder, type); }; mlir::func::FuncOp newFunc = getOrCreateFunction( builder, funcName, typeGenerator, genFortranASumBody); @@ -431,6 +490,7 @@ const mlir::Value &v2 = args[1]; mlir::Location loc = call.getLoc(); fir::FirOpBuilder builder(op, kindMap); + mlir::Type type = call.getResult(0).getType(); if (!type.isa() && !type.isa()) return; @@ -481,6 +541,38 @@ llvm::dbgs() << "\n"); return; } + if (funcName.startswith("_FortranAMaxval")) { + mlir::Operation::operand_range args = call.getArgs(); + const mlir::Value &dim = args[3]; + const mlir::Value &mask = args[4]; + // dim is zero when it is absent, which is an implementation + // detail in the runtime library.
+ bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); + unsigned rank = getDimCount(args[0]); + if (dimAndMaskAbsent && rank == 1) { + mlir::Location loc = call.getLoc(); + mlir::Type type; + fir::FirOpBuilder builder(op, kindMap); + if (funcName.endswith("Integer4")) { + type = mlir::IntegerType::get(builder.getContext(), 32); + } else if (funcName.endswith("Real8")) { + type = mlir::FloatType::getF64(builder.getContext()); + } else { + return; + } + auto typeGenerator = [&type](fir::FirOpBuilder &builder) { + return genNoneBoxType(builder, type); + }; + mlir::func::FuncOp newFunc = getOrCreateFunction( + builder, funcName, typeGenerator, genFortranAMaxvalBody); + auto newCall = builder.create( + loc, newFunc, mlir::ValueRange{args[0]}); + call->replaceAllUsesWith(newCall.getResults()); + call->dropAllReferences(); + call->erase(); + return; + } + } } } }); diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -42,23 +42,19 @@ // CHECK-LABEL: func.func private @_FortranASumInteger4_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32 -// CHECK: %[[SUM:.*]] = fir.alloca i32 -// CHECK: fir.store %[[CI32_0]] to %[[SUM]] : !fir.ref // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> // CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) // CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index // CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index -// CHECK: fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] { +// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to
%[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[CI32_0]]) -> (i32) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref -// CHECK: %[[SUM_VAL:.*]] = fir.load %[[SUM]] : !fir.ref -// CHECK: %[[NEW_SUM:.*]] = arith.addi %[[ITEM_VAL]], %[[SUM_VAL]] : i32 -// CHECK: fir.store %[[NEW_SUM]] to %[[SUM]] : !fir.ref +// CHECK: %[[NEW_SUM:.*]] = arith.addi %[[ITEM_VAL]], %[[SUM]] : i32 +// CHECK: fir.result %[[NEW_SUM]] : i32 // CHECK: } -// CHECK: %[[RET:.*]] = fir.load %[[SUM]] : !fir.ref -// CHECK: return %[[RET]] : i32 +// CHECK: return %[[RES]] : i32 // CHECK: } // ----- @@ -140,22 +136,18 @@ // CHECK-LABEL: func.func private @_FortranASumReal8_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> f64 attributes {llvm.linkage = #llvm.linkage} { -// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[SUM:.*]] = fir.alloca f64 -// CHECK: fir.store %[[ZERO]] to %[[SUM]] : !fir.ref // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> // CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) // CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index // CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index -// CHECK: fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] { +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[ZERO]]) -> (f64) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref -// CHECK: %[[SUM_VAL:.*]] = fir.load %[[SUM]] : !fir.ref -// CHECK: %[[NEW_SUM:.*]] = arith.addf %[[ITEM_VAL]], %[[SUM_VAL]] : f64 -// CHECK:
fir.store %[[NEW_SUM]] to %[[SUM]] : !fir.ref +// CHECK: %[[NEW_SUM:.*]] = arith.addf %[[ITEM_VAL]], %[[SUM]] : f64 +// CHECK: fir.result %[[NEW_SUM]] : f64 // CHECK: } -// CHECK: %[[RES:.*]] = fir.load %[[SUM]] : !fir.ref // CHECK: return %[[RES]] : f64 // CHECK: } @@ -312,10 +304,10 @@ // CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box>, index) -> (index, index, index) // CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index // CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index -// CHECK: fir.do_loop %[[ITER:.*]] = %{{.*}} to %[[EXTENT]] step %[[CINDEX_1]] { +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %{{.*}} to %[[EXTENT]] step %[[CINDEX_1]] iter_args({{.*}}) -> (i32) { // CHECK: %{{.*}} = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: } -// CHECK: return %{{.*}} : i32 +// CHECK: return %[[RES]] : i32 // CHECK: } // ----- @@ -705,3 +697,116 @@ // CHECK: } // CHECK: return %[[RES]] : f64 // CHECK: } + + +// ----- + +// Call to MAXVAL with 1D I32 array is replaced.
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} { + func.func @maxval_1d_array_int(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> i32 { + %c10 = arith.constant 10 : index + %0 = fir.alloca i32 {bindc_name = "test_max_2", uniq_name = "_QFtest_max_2Etest_max_2"} + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.embox %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + %3 = fir.absent !fir.box + %c0 = arith.constant 0 : index + %4 = fir.address_of(@_QQcl.2E2F696D61785F322E66393000) : !fir.ref> + %c5_i32 = arith.constant 5 : i32 + %5 = fir.convert %2 : (!fir.box>) -> !fir.box + %6 = fir.convert %4 : (!fir.ref>) -> !fir.ref + %7 = fir.convert %c0 : (index) -> i32 + %8 = fir.convert %3 : (!fir.box) -> !fir.box + %9 = fir.call @_FortranAMaxvalInteger4(%5, %6, %c5_i32, %7, %8) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 + fir.store %9 to %0 : !fir.ref + %10 = fir.load %0 : !fir.ref + return %10 : i32 + } + func.func private @_FortranAMaxvalInteger4(!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 attributes {fir.runtime} + fir.global linkonce @_QQcl.2E2F696D61785F322E66393000 constant : !fir.char<1,13> { + %0 = fir.string_lit "./imax_2.f90\00"(13) : !fir.char<1,13> + fir.has_value %0 : !fir.char<1,13> + } +} + +// CHECK-LABEL: func.func @maxval_1d_array_int( +// CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a"}) -> i32 { +// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1> +// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box>) -> !fir.box +// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 +// CHECK: return %{{.*}} : i32 +// CHECK: } + +// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4_simplified( +// CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK:
%[[CINDEX_0:.*]] = arith.constant 0 : index +// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index +// CHECK: %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32 +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX:.*]] = %[[CI32_MININT]]) -> (i32) { +// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref +// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref +// CHECK: %[[NEW_MAX:.*]] = arith.maxsi %[[ITEM_VAL]], %[[MAX]] : i32 +// CHECK: fir.result %[[NEW_MAX]] : i32 +// CHECK: } +// CHECK: return %[[RES]] : i32 +// CHECK: } + +// ----- + +// Call to MAXVAL with 1D F64 array is replaced. +module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} { + func.func @maxval_1d_real(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> f64 { + %c10 = arith.constant 10 : index + %0 = fir.alloca f64 {bindc_name = "maxval_1d_real", uniq_name = "_QFmaxval_1d_realEmaxval_1d_real"} + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.embox %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + %3 = fir.absent !fir.box + %c0 = arith.constant 0 : index + %4 = fir.address_of(@_QQcl.2E2F696D61785F352E66393000) : !fir.ref> + %c5_i32 = arith.constant 5 : i32 + %5 = fir.convert %2 : (!fir.box>) -> !fir.box + %6 = fir.convert %4 : (!fir.ref>) -> !fir.ref + %7 = fir.convert %c0 : (index) -> i32 + %8 = fir.convert %3 : (!fir.box) -> !fir.box + %9 = fir.call @_FortranAMaxvalReal8(%5, %6, %c5_i32, %7, %8) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> f64 + fir.store %9 to %0 : !fir.ref + %10 = fir.load %0 : !fir.ref + return %10 : f64 + } + func.func private
@_FortranAMaxvalReal8(!fir.box, !fir.ref, i32, i32, !fir.box) -> f64 attributes {fir.runtime} + fir.global linkonce @_QQcl.2E2F696D61785F352E66393000 constant : !fir.char<1,13> { + %0 = fir.string_lit "./imax_5.f90\00"(13) : !fir.char<1,13> + fir.has_value %0 : !fir.char<1,13> + } +} + + +// CHECK-LABEL: func.func @maxval_1d_real( +// CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a"}) -> f64 { +// CHECK: %[[CINDEX_10:.*]] = arith.constant 10 : index +// CHECK: %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1> +// CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box>) -> !fir.box +// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f64 +// CHECK: return %{{.*}} : f64 +// CHECK: } + +// CHECK-LABEL: func.func private @_FortranAMaxvalReal8_simplified( +// CHECK-SAME: %[[ARR:.*]]: !fir.box) -> f64 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index +// CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index +// CHECK: %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64 +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX:.*]] = %[[NEG_DBL_MAX]]) -> (f64) { +// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box>, index) -> !fir.ref +// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref +// CHECK: %[[NEW_MAX:.*]] = arith.maxf %[[ITEM_VAL]], %[[MAX]] : f64 +// CHECK: fir.result %[[NEW_MAX]] : f64 +// CHECK: } +// CHECK: return %[[RES]] : f64 +// CHECK: }