diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -31,11 +31,14 @@
 #include "flang/Optimizer/Support/FIRContext.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "flang-simplify-intrinsics"
 
@@ -159,8 +162,13 @@
 /// with signature provided by \p funcOp. The caller is responsible
 /// for saving/restoring the original insertion point of \p builder.
 /// \p funcOp is expected to be empty on entry to this function.
+/// \p arg1ElementTy and \p arg2ElementTy specify the element types
+/// of the underlying array objects - they are used to generate proper
+/// element accesses.
 static void genFortranADotBody(fir::FirOpBuilder &builder,
-                               mlir::func::FuncOp &funcOp) {
+                               mlir::func::FuncOp &funcOp,
+                               mlir::Type arg1ElementTy,
+                               mlir::Type arg2ElementTy) {
   // function FortranADotProduct_simplified(arr1, arr2)
   //   T, dimension(:) :: arr1, arr2
   //   T product = 0
@@ -171,14 +179,15 @@
   //   FortranADotProduct_simplified = product
   // end function FortranADotProduct_simplified
   auto loc = mlir::UnknownLoc::get(builder.getContext());
-  mlir::Type elementType = funcOp.getResultTypes()[0];
+  mlir::Type resultElementType = funcOp.getResultTypes()[0];
   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
 
   mlir::IndexType idxTy = builder.getIndexType();
 
-  mlir::Value zero = elementType.isa<mlir::FloatType>()
-                         ? builder.createRealConstant(loc, elementType, 0.0)
-                         : builder.createIntegerConstant(loc, elementType, 0);
+  mlir::Value zero =
+      resultElementType.isa<mlir::FloatType>()
+          ? builder.createRealConstant(loc, resultElementType, 0.0)
+          : builder.createIntegerConstant(loc, resultElementType, 0);
 
   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
   mlir::Value arg1 = args[0];
@@ -187,10 +196,12 @@
   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
 
   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
-  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
-  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
-  mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy, arg1);
-  mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy, arg2);
+  mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
+  mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
+  mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
+  mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
+  mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
+  mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
   // This version takes the loop trip count from the first argument.
   // If the first argument's box has unknown (at compilation time)
   // extent, then it may be better to take the extent from the second
@@ -216,19 +227,25 @@
   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
   builder.setInsertionPointToStart(loop.getBody());
 
-  mlir::Type eleRefTy = builder.getRefType(elementType);
+  mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
   mlir::Value index = loop.getInductionVar();
   mlir::Value addr1 =
-      builder.create<fir::CoordinateOp>(loc, eleRefTy, array1, index);
+      builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
+  // Convert to the result type.
+  elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
+
+  mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
   mlir::Value addr2 =
-      builder.create<fir::CoordinateOp>(loc, eleRefTy, array2, index);
+      builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
+  // Convert to the result type.
+  elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
 
-  if (elementType.isa<mlir::FloatType>())
+  if (resultElementType.isa<mlir::FloatType>())
     sumVal = builder.create<mlir::arith::AddFOp>(
         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
-  else if (elementType.isa<mlir::IntegerType>())
+  else if (resultElementType.isa<mlir::IntegerType>())
     sumVal = builder.create<mlir::arith::AddIOp>(
         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
   else
@@ -317,6 +334,32 @@
   return 0;
 }
 
+/// Given the call operation's box argument \p val, discover
+/// the element type of the underlying array object.
+/// \returns the element type or llvm::None if the type cannot
+/// be reliably found.
+/// We expect that the argument is a result of fir.convert
+/// with the destination type of !fir.box<none>.
+static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
+  mlir::Operation *defOp;
+  do {
+    defOp = val.getDefiningOp();
+    // Analyze only sequences of convert operations. A block argument
+    // (e.g. a box received as a function argument) has no defining op.
+    if (!defOp || !mlir::isa<fir::ConvertOp>(defOp))
+      return llvm::None;
+    val = defOp->getOperand(0);
+    // The convert operation is expected to convert from one
+    // box type to another box type; give up on anything else.
+    auto boxType = val.getType().dyn_cast<fir::BoxType>();
+    if (!boxType)
+      return llvm::None;
+    auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
+    if (!elementType.isa<mlir::NoneType>())
+      return elementType;
+  } while (true);
+}
+
 void SimplifyIntrinsicsPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::ModuleOp module = getOperation();
@@ -380,11 +423,42 @@
           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
             return;
 
+          // Try to find the element types of the boxed arguments.
+          auto arg1Type = getArgElementType(v1);
+          auto arg2Type = getArgElementType(v2);
+
+          if (!arg1Type || !arg2Type)
+            return;
+
+          // Support only floating point and integer arguments
+          // now (e.g. logical is skipped here).
+          if (!arg1Type->isa<mlir::FloatType>() &&
+              !arg1Type->isa<mlir::IntegerType>())
+            return;
+          if (!arg2Type->isa<mlir::FloatType>() &&
+              !arg2Type->isa<mlir::IntegerType>())
+            return;
+
           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
             return genFortranADotType(builder, type);
           };
+          auto bodyGenerator = [&arg1Type,
+                                &arg2Type](fir::FirOpBuilder &builder,
+                                           mlir::func::FuncOp &funcOp) {
+            genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
+          };
+
+          // Suffix the function name with the element types
+          // of the arguments.
+          std::string typedFuncName(funcName);
+          llvm::raw_string_ostream nameOS(typedFuncName);
+          nameOS << "_";
+          arg1Type->print(nameOS);
+          nameOS << "_";
+          arg2Type->print(nameOS);
+
           mlir::func::FuncOp newFunc = getOrCreateFunction(
-              builder, funcName, typeGenerator, genFortranADotBody);
+              builder, typedFuncName, typeGenerator, bodyGenerator);
           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
                                                      mlir::ValueRange{v1, v2});
           call->replaceAllUsesWith(newCall.getResults());
diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -344,15 +344,15 @@
 // CHECK:           %[[RESLOC:.*]] = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
 // CHECK:           %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
 // CHECK:           %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
-// CHECK:           %[[RES:.*]] = fir.call @_FortranADotProductReal4_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> f32
+// CHECK:           %[[RES:.*]] = fir.call @_FortranADotProductReal4_f32_f32_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> f32
 // CHECK:           fir.store %[[RES]] to %[[RESLOC]] : !fir.ref<f32>
 // CHECK:           %[[RET:.*]] = fir.load %[[RESLOC]] : !fir.ref<f32>
 // CHECK:           return %[[RET]] : f32
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranADotProductReal4_simplified(
-// CHECK-SAME:                                                           %[[A:.*]]: !fir.box<none>,
-// CHECK-SAME:                                                           %[[B:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-LABEL:   func.func private @_FortranADotProductReal4_f32_f32_simplified(
+// CHECK-SAME:                                                                   %[[A:.*]]: !fir.box<none>,
+// CHECK-SAME:                                                                   %[[B:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[FZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[IZERO:.*]] = arith.constant 0 : index
 // CHECK:           %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
@@ -363,9 +363,11 @@
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[IDX:.*]] = %[[IZERO]] to %[[LEN]] step %[[IONE]] iter_args(%[[SUM:.*]] = %[[FZERO]]) -> (f32) {
 // CHECK:             %[[ALOC:.*]] = fir.coordinate_of %[[ACAST]], %[[IDX]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK:             %[[AVAL:.*]] = fir.load %[[ALOC]] : !fir.ref<f32>
+// CHECK:             %[[AVALCAST:.*]] = fir.convert %[[AVAL]] : (f32) -> f32
 // CHECK:             %[[BLOC:.*]] = fir.coordinate_of %[[BCAST]], %[[IDX]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK:             %[[BVAL:.*]] = fir.load %[[BLOC]] : !fir.ref<f32>
-// CHECK:             %[[MUL:.*]] = arith.mulf %[[AVAL]], %[[BVAL]] : f32
+// CHECK:             %[[BVALCAST:.*]] = fir.convert %[[BVAL]] : (f32) -> f32
+// CHECK:             %[[MUL:.*]] = arith.mulf %[[AVALCAST]], %[[BVALCAST]] : f32
 // CHECK:             %[[NEWSUM:.*]] = arith.addf %[[MUL]], %[[SUM]] : f32
 // CHECK:             fir.result %[[NEWSUM]] : f32
 // CHECK:           }
@@ -479,15 +481,15 @@
 // CHECK:           %[[RESLOC:.*]] = fir.alloca i32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
 // CHECK:           %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK:           %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
-// CHECK:           %[[RES:.*]] = fir.call @_FortranADotProductInteger4_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> i32
+// CHECK:           %[[RES:.*]] = fir.call @_FortranADotProductInteger4_i32_i32_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> i32
 // CHECK:           fir.store %[[RES]] to %[[RESLOC]] : !fir.ref<i32>
 // CHECK:           %[[RET:.*]] = fir.load %[[RESLOC]] : !fir.ref<i32>
 // CHECK:           return %[[RET]] : i32
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_FortranADotProductInteger4_simplified(
-// CHECK-SAME:                                                              %[[A:.*]]: !fir.box<none>,
-// CHECK-SAME:                                                              %[[B:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-LABEL:   func.func private @_FortranADotProductInteger4_i32_i32_simplified(
+// CHECK-SAME:                                                                      %[[A:.*]]: !fir.box<none>,
+// CHECK-SAME:                                                                      %[[B:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
 // CHECK:           %[[I32ZERO:.*]] = arith.constant 0 : i32
 // CHECK:           %[[IZERO:.*]] = arith.constant 0 : index
 // CHECK:           %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
@@ -498,9 +500,11 @@
 // CHECK:           %[[RES:.*]] = fir.do_loop %[[IDX:.*]] = %[[IZERO]] to %[[LEN]] step %[[IONE]] iter_args(%[[SUM:.*]] = %[[I32ZERO]]) -> (i32) {
 // CHECK:             %[[ALOC:.*]] = fir.coordinate_of %[[ACAST]], %[[IDX]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK:             %[[AVAL:.*]] = fir.load %[[ALOC]] : !fir.ref<i32>
+// CHECK:             %[[AVALCAST:.*]] = fir.convert %[[AVAL]] : (i32) -> i32
 // CHECK:             %[[BLOC:.*]] = fir.coordinate_of %[[BCAST]], %[[IDX]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK:             %[[BVAL:.*]] = fir.load %[[BLOC]] : !fir.ref<i32>
-// CHECK:             %[[MUL:.*]] = arith.muli %[[AVAL]], %[[BVAL]] : i32
+// CHECK:             %[[BVALCAST:.*]] = fir.convert %[[BVAL]] : (i32) -> i32
+// CHECK:             %[[MUL:.*]] = arith.muli %[[AVALCAST]], %[[BVALCAST]] : i32
 // CHECK:             %[[NEWSUM:.*]] = arith.addi %[[MUL]], %[[SUM]] : i32
 // CHECK:             fir.result %[[NEWSUM]] : i32
 // CHECK:           }
@@ -587,3 +591,63 @@
 // CHECK-SAME:                   %[[A:.*]]: !fir.box<!fir.array<?xi64>> {fir.bindc_name = "a"},
 // CHECK-SAME:                   %[[B:.*]]: !fir.box<!fir.array<?xi64>> {fir.bindc_name = "b"}) -> i64 {
 // CHECK-NOT:         call{{.*}}_FortranADotProductInteger8(
+
+// -----
+
+// Test mixed types, e.g. when _FortranADotProductReal8 is called
+// with REAL(8) and REAL(4) arguments. The loaded elements must be converted
+// to the result type REAL(8) before the computations.
+
+func.func @dot_f64_f32(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f64 {
+  %0 = fir.alloca f64 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
+  %1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref<!fir.char<1,10>>
+  %c3_i32 = arith.constant 3 : i32
+  %2 = fir.convert %arg0 : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+  %3 = fir.convert %arg1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  %4 = fir.convert %1 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
+  %5 = fir.call @_FortranADotProductReal8(%2, %3, %4, %c3_i32) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f64
+  fir.store %5 to %0 : !fir.ref<f64>
+  %6 = fir.load %0 : !fir.ref<f64>
+  return %6 : f64
+}
+func.func private @_FortranADotProductReal8(!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f64 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
+  %0 = fir.string_lit "./dot.f90\00"(10) : !fir.char<1,10>
+  fir.has_value %0 : !fir.char<1,10>
+}
+
+// CHECK-LABEL:   func.func @dot_f64_f32(
+// CHECK-SAME:                           %[[A:.*]]: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"},
+// CHECK-SAME:                           %[[B:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f64 {
+// CHECK:           %[[RESLOC:.*]] = fir.alloca f64 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
+// CHECK:           %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK:           %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+// CHECK:           %[[RES:.*]] = fir.call @_FortranADotProductReal8_f64_f32_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> f64
+// CHECK:           fir.store %[[RES]] to %[[RESLOC]] : !fir.ref<f64>
+// CHECK:           %[[RET:.*]] = fir.load %[[RESLOC]] : !fir.ref<f64>
+// CHECK:           return %[[RET]] : f64
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_FortranADotProductReal8_f64_f32_simplified(
+// CHECK-SAME:                                                                   %[[A:.*]]: !fir.box<none>,
+// CHECK-SAME:                                                                   %[[B:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[FZERO:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:           %[[IZERO:.*]] = arith.constant 0 : index
+// CHECK:           %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
+// CHECK:           %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
+// CHECK:           %[[DIMS:.*]]:3 = fir.box_dims %[[ACAST]], %[[IZERO]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
+// CHECK:           %[[IONE:.*]] = arith.constant 1 : index
+// CHECK:           %[[LEN:.*]] = arith.subi %[[DIMS]]#1, %[[IONE]] : index
+// CHECK:           %[[RES:.*]] = fir.do_loop %[[IDX:.*]] = %[[IZERO]] to %[[LEN]] step %[[IONE]] iter_args(%[[SUM:.*]] = %[[FZERO]]) -> (f64) {
+// CHECK:             %[[ALOC:.*]] = fir.coordinate_of %[[ACAST]], %[[IDX]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
+// CHECK:             %[[AVAL:.*]] = fir.load %[[ALOC]] : !fir.ref<f64>
+// CHECK:             %[[AVALCAST:.*]] = fir.convert %[[AVAL]] : (f64) -> f64
+// CHECK:             %[[BLOC:.*]] = fir.coordinate_of %[[BCAST]], %[[IDX]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:             %[[BVAL:.*]] = fir.load %[[BLOC]] : !fir.ref<f32>
+// CHECK:             %[[BVALCAST:.*]] = fir.convert %[[BVAL]] : (f32) -> f64
+// CHECK:             %[[MUL:.*]] = arith.mulf %[[AVALCAST]], %[[BVALCAST]] : f64
+// CHECK:             %[[NEWSUM:.*]] = arith.addf %[[MUL]], %[[SUM]] : f64
+// CHECK:             fir.result %[[NEWSUM]] : f64
+// CHECK:           }
+// CHECK:           return %[[RES]] : f64
+// CHECK:         }