diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -231,18 +231,18 @@ // Handle required vector arguments mlir::Value vectorA = fir::getBase(args[0]); mlir::Value vectorB = fir::getBase(args[1]); + // Result type is used for picking appropriate runtime function. + mlir::Type eleTy = resultType; - mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(vectorA.getType()) - .cast() - .getEleTy(); if (fir::isa_complex(eleTy)) { mlir::Value result = builder.createTemporary(loc, eleTy); func(builder, loc, vectorA, vectorB, result); return builder.create(loc, result); } - auto resultBox = builder.create( - loc, fir::BoxType::get(builder.getI1Type())); + // This operation is only used to pass the result type + // information to the DotProduct generator. + auto resultBox = builder.create(loc, fir::BoxType::get(eleTy)); return func(builder, loc, vectorA, vectorB, resultBox); } diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp --- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp @@ -799,9 +799,10 @@ mlir::Value vectorBBox, mlir::Value resultBox) { mlir::func::FuncOp func; - auto ty = vectorABox.getType(); - auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty); - auto eleTy = arrTy.cast().getEleTy(); + // For complex data types, resultBox is !fir.ref>, + // otherwise it is !fir.box. + auto ty = resultBox.getType(); + auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(ty); if (eleTy.isF16() || eleTy.isBF16()) TODO(loc, "half-precision DOTPRODUCT"); diff --git a/flang/test/Lower/Intrinsics/dot_product.f90 b/flang/test/Lower/Intrinsics/dot_product.f90 --- a/flang/test/Lower/Intrinsics/dot_product.f90 +++ b/flang/test/Lower/Intrinsics/dot_product.f90 @@ -245,3 +245,46 @@ ! CHECK-DAG: %[[res:.*]] = fir.call @_FortranADotProductLogical(%[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.box, !fir.box, !fir.ref, i32) -> i1 z = dot_product(x,y) end subroutine + +! CHECK-LABEL: dot_product_mixed_int_real +! CHECK-SAME: %[[x:arg0]]: !fir.box> +! CHECK-SAME: %[[y:arg1]]: !fir.box> +! CHECK-SAME: %[[z:arg2]]: !fir.box> +subroutine dot_product_mixed_int_real(x, y, z) + integer, dimension(1:) :: x + real, dimension(1:) :: y, z + ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box>) -> !fir.box + ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box>) -> !fir.box + ! CHECK-DAG: %[[res:.*]] = fir.call @_FortranADotProductReal4(%[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.box, !fir.box, !fir.ref, i32) -> f32 + z = dot_product(x,y) +end subroutine + +! CHECK-LABEL: dot_product_mixed_int_complex +! CHECK-SAME: %[[x:arg0]]: !fir.box> +! CHECK-SAME: %[[y:arg1]]: !fir.box>> +! CHECK-SAME: %[[z:arg2]]: !fir.box>> +subroutine dot_product_mixed_int_complex(x, y, z) + integer, dimension(1:) :: x + complex, dimension(1:) :: y, z + ! CHECK-DAG: %[[res:.*]] = fir.alloca !fir.complex<4> + ! CHECK-DAG: %[[res_conv:.*]] = fir.convert %[[res]] : (!fir.ref>) -> !fir.ref> + ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box>) -> !fir.box + ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box>>) -> !fir.box + ! CHECK-DAG: fir.call @_FortranACppDotProductComplex4(%[[res_conv]], %[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.ref>, !fir.box, !fir.box, !fir.ref, i32) -> none + z = dot_product(x,y) +end subroutine + +! CHECK-LABEL: dot_product_mixed_real_complex +! CHECK-SAME: %[[x:arg0]]: !fir.box> +! CHECK-SAME: %[[y:arg1]]: !fir.box>> +! CHECK-SAME: %[[z:arg2]]: !fir.box>> +subroutine dot_product_mixed_real_complex(x, y, z) + real, dimension(1:) :: x + complex, dimension(1:) :: y, z + ! CHECK-DAG: %[[res:.*]] = fir.alloca !fir.complex<4> + ! CHECK-DAG: %[[res_conv:.*]] = fir.convert %[[res]] : (!fir.ref>) -> !fir.ref> + ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box>) -> !fir.box + ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box>>) -> !fir.box + ! CHECK-DAG: fir.call @_FortranACppDotProductComplex4(%[[res_conv]], %[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.ref>, !fir.box, !fir.box, !fir.ref, i32) -> none + z = dot_product(x,y) +end subroutine diff --git a/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp b/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp --- a/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp +++ b/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp @@ -202,7 +202,8 @@ mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); mlir::Value a = builder.create(loc, refSeqTy); mlir::Value b = builder.create(loc, refSeqTy); - mlir::Value result = builder.create(loc, seqTy); + mlir::Value result = + builder.create(loc, fir::ReferenceType::get(eleTy)); mlir::Value prod = fir::runtime::genDotProduct(builder, loc, a, b, result); if (fir::isa_complex(eleTy)) checkCallOpFromResultBox(result, fctName, 3);