diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -430,6 +430,31 @@
   let hasVerifier = 1;
 }
 
+def hlfir_DotProductOp : hlfir_Op<"dot_product",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "DOT_PRODUCT transformational intrinsic";
+  let description = [{
+    Dot product of two vectors, as defined by the Fortran DOT_PRODUCT
+    intrinsic. Both operands must be rank-1 arrays; the result is a scalar,
+    which is logical if and only if the operands are logical.
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs AnyFortranValue);
+
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_MatmulOp : hlfir_Op<"matmul",
     [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
   let summary = "MATMUL transformational intrinsic";
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -1488,6 +1488,15 @@
     return buildReductionIntrinsic(loweredActuals, loc, builder, callContext,
                                    buildAllOperation, false);
   }
+  if (intrinsicName == "dot_product") {
+    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
+    mlir::Type resultTy =
+        computeResultType(operands[0], *callContext.resultType);
+    hlfir::DotProductOp dotProductOp = builder.create<hlfir::DotProductOp>(
+        loc, resultTy, operands[0], operands[1]);
+
+    return {hlfir::EntityWithAttributes{dotProductOp.getResult()}};
+  }
 
   // TODO add hlfir operations for other transformational intrinsics here
 
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -669,6 +669,52 @@
   return verifyNumericalReductionOp(this);
 }
 
+//===----------------------------------------------------------------------===//
+// DotProductOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::DotProductOp::verify() {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(lhs.getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(rhs.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  mlir::Type resultTy = getResult().getType();
+
+  if ((lhsShape.size() != 1) || (rhsShape.size() != 1))
+    return emitOpError("both arrays must have rank 1");
+
+  // Safe to index now; extents unknown at compile time are not compared.
+  const int64_t lhsSize = lhsShape[0];
+  const int64_t rhsSize = rhsShape[0];
+  const int64_t unknown = fir::SequenceType::getUnknownExtent();
+  if (lhsSize != unknown && rhsSize != unknown && lhsSize != rhsSize)
+    return emitOpError("both arrays must have the same size");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(resultTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+
+  if (!hlfir::isFortranScalarNumericalType(resultTy) &&
+      !mlir::isa<fir::LogicalType>(resultTy))
+    return emitOpError(
+        "the result must be of scalar numerical or logical type");
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -277,6 +277,39 @@
   }
 };
 
+// Lower hlfir.dot_product through fir::genIntrinsicCall("dot_product").
+struct DotProductOpConversion
+    : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
+  using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::DotProductOp dotProduct,
+                  mlir::PatternRewriter &rewriter) const override {
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    mlir::Location loc = dotProduct->getLoc();
+
+    mlir::Value lhs = dotProduct.getLhs();
+    mlir::Value rhs = dotProduct.getRhs();
+    llvm::SmallVector<IntrinsicArgument, 2> inArgs;
+    inArgs.push_back({lhs, lhs.getType()});
+    inArgs.push_back({rhs, rhs.getType()});
+
+    auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product");
+    llvm::SmallVector<fir::ExtendedValue, 2> args =
+        lowerArguments(dotProduct, inArgs, rewriter, argLowering);
+
+    mlir::Type scalarResultType =
+        hlfir::getFortranElementType(dotProduct.getType());
+
+    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+        builder, loc, "dot_product", scalarResultType, args);
+
+    processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
+};
+
 class TransposeOpConversion
     : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
   using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
@@ -356,14 +389,15 @@
     mlir::RewritePatternSet patterns(context);
     patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
                     AllOpConversion, AnyOpConversion, SumOpConversion,
-                    ProductOpConversion, TransposeOpConversion>(context);
+                    ProductOpConversion, TransposeOpConversion,
+                    DotProductOpConversion>(context);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
                            mlir::func::FuncDialect, fir::FIROpsDialect,
                            hlfir::hlfirDialect>();
     target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
                         hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
-                        hlfir::AllOp>();
+                        hlfir::AllOp, hlfir::DotProductOp>();
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(
diff --git a/flang/test/HLFIR/dot_product-lowering.fir b/flang/test/HLFIR/dot_product-lowering.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/HLFIR/dot_product-lowering.fir
@@ -0,0 +1,80 @@
+// Test hlfir.dot_product operation lowering to fir runtime call
+// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s
+
+func.func @_QPdot_product1(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<i32> {fir.bindc_name = "res"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product1Elhs"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product1Eres"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product1Erhs"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
+  hlfir.assign %3 to %1#0 : i32, !fir.ref<i32>
+  return
+}
+// CHECK-LABEL: func.func @_QPdot_product1(
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+// CHECK-DAG:       %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:       %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:       %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+
+// CHECK-DAG:       %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+// CHECK-DAG:       %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+
+// CHECK:           %[[NONE:.*]] = fir.call @_FortranADotProductInteger4(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-NEXT:      hlfir.assign %[[NONE]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
+
+func.func @_QPdot_product2(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<!fir.logical<4>> {fir.bindc_name = "res"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product2Elhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product2Eres"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product2Erhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
+  %3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+  hlfir.assign %3 to %1#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+  return
+}
+// CHECK-LABEL: func.func 
@_QPdot_product2(
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.ref<!fir.logical<4>> {fir.bindc_name = "res"}
+// CHECK-DAG:       %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:       %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:       %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+
+// CHECK-DAG:       %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
+// CHECK-DAG:       %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
+
+// CHECK:           %[[NONE:.*]] = fir.call @_FortranADotProductLogical(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-NEXT:      hlfir.assign %[[NONE]] to %[[RES_VAR]]#0 : i1, !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
+
+func.func @_QPdot_product3(%arg0: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "lhs"}, %arg1: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<i32> {fir.bindc_name = "res"}) {
+  %c5 = arith.constant 5 : index
+  %0 = fir.shape %c5 : (index) -> !fir.shape<1>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFdot_product3Elhs"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+  %2:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product3Eres"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %c5_0 = arith.constant 5 : index
+  %3 = fir.shape %c5_0 : (index) -> !fir.shape<1>
+  %4:2 = hlfir.declare %arg1(%3) {uniq_name = "_QFdot_product3Erhs"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+  %5 = hlfir.dot_product %1#0 %4#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>) -> i32
+  hlfir.assign %5 to %2#0 : i32, !fir.ref<i32>
+  return
+}
+// CHECK-LABEL: func.func @_QPdot_product3(
+// CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+// CHECK-DAG:       %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:       %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:       %[[RES_VAR:.*]]:2 = 
hlfir.declare %[[ARG2]]
+
+// CHECK-DAG:       %[[LHS_BOX:.*]] = fir.embox %[[LHS_VAR]]#1
+// CHECK-DAG:       %[[RHS_BOX:.*]] = fir.embox %[[RHS_VAR]]#1
+// CHECK-DAG:       %[[LHS_ARG:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box<!fir.array<5xi32>>) -> !fir.box<none>
+// CHECK-DAG:       %[[RHS_ARG:.*]] = fir.convert %[[RHS_BOX]] : (!fir.box<!fir.array<5xi32>>) -> !fir.box<none>
+
+// CHECK:           %[[NONE:.*]] = fir.call @_FortranADotProductInteger4(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-NEXT:      hlfir.assign %[[NONE]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
diff --git a/flang/test/HLFIR/dot_product.fir b/flang/test/HLFIR/dot_product.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/HLFIR/dot_product.fir
@@ -0,0 +1,51 @@
+// Test hlfir.dot_product operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// arguments are expressions of known shape
+func.func @dot_product0(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> i32
+  return
+}
+// CHECK-LABEL: func.func @dot_product0
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2xi32>
+// CHECK-NEXT:      %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> i32
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
+
+// arguments are expressions of unknown shape
+func.func @dot_product1(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> i32
+  return
+}
+// CHECK-LABEL: func.func @dot_product1
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?xi32>
+// CHECK-NEXT:      %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> i32
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
+
+// arguments are boxed arrays
+func.func @dot_product2(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<?xi32>>) {
+  %res = hlfir.dot_product %arg0 %arg1 : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
+  return
+}
+// CHECK-LABEL: func.func @dot_product2
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>>
+// CHECK-NEXT:      %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
+
+// arguments are logical
+func.func @dot_product3(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+  %res = hlfir.dot_product %arg0 %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+  return
+}
+// CHECK-LABEL: func.func @dot_product3
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
+// CHECK-NEXT:      %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT:      return
+// CHECK-NEXT:    }
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -508,6 +508,41 @@
   return
 }
 
+// -----
+func.func @bad_dot_product1(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op both arrays must have rank 1}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x3xi32>) -> i32
+  return
+}
+
+// -----
+func.func @bad_dot_product2(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<3xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op both arrays must have the same size}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<3xi32>) -> i32
+  return
+}
+
+// -----
+func.func @bad_dot_product3(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.dot_product' op if one array is logical, so should the other be}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x!fir.logical<4>>) -> i32
+  return
+}
+
+// -----
+func.func @bad_dot_product4(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op the result type should be a logical only if the argument types are logical}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> !fir.logical<4>
+  return
+}
+
+// -----
+func.func @bad_dot_product5(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op the result must be of scalar numerical or logical type}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<i32>
+  return
+}
+
 // -----
 func.func @bad_transpose1(%arg0: !hlfir.expr<2xi32>) {
   // expected-error@+1 {{'hlfir.transpose' op input and output arrays should have rank 2}}
diff --git a/flang/test/Lower/HLFIR/dot_product.f90 b/flang/test/Lower/HLFIR/dot_product.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/HLFIR/dot_product.f90
@@ -0,0 +1,53 @@
+! Test lowering of DOT_PRODUCT intrinsic to HLFIR
+! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s
+
+! dot product with numerical arguments
+subroutine dot_product1(lhs, rhs, res)
+  integer lhs(:), rhs(:), res
+  res = DOT_PRODUCT(lhs,rhs)
+end subroutine
+! CHECK-LABEL: func.func @_QPdot_product1
+! CHECK:           %[[LHS:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "lhs"}
+! CHECK:           %[[RHS:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+! CHECK-DAG:       %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:       %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:       %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:      %[[EXPR:.*]] = hlfir.dot_product %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
+! CHECK-NEXT:      hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+! CHECK-NEXT:      return
+! CHECK-NEXT:    }
+
+! dot product with logical arguments
+subroutine dot_product2(lhs, rhs, res)
+  logical lhs(:), rhs(:), res
+  res = DOT_PRODUCT(lhs,rhs)
+end subroutine
+! CHECK-LABEL: func.func @_QPdot_product2
+! CHECK:           %[[LHS:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "lhs"}
+!
CHECK:           %[[RHS:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.ref<!fir.logical<4>> {fir.bindc_name = "res"}
+! CHECK-DAG:       %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:       %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:       %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:      %[[EXPR:.*]] = hlfir.dot_product %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+! CHECK-NEXT:      hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+! CHECK-NEXT:      return
+! CHECK-NEXT:    }
+
+! explicit-shape arguments, passed by reference rather than as boxes
+subroutine dot_product3(lhs, rhs, res)
+  integer lhs(5), rhs(5), res
+  res = DOT_PRODUCT(lhs,rhs)
+end subroutine
+! CHECK-LABEL: func.func @_QPdot_product3
+! CHECK:           %[[LHS:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "lhs"}
+! CHECK:           %[[RHS:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+! CHECK-DAG:       %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:       %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:       %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:      %[[EXPR:.*]] = hlfir.dot_product %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>) -> i32
+! CHECK-NEXT:      hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+! CHECK-NEXT:      return
+! CHECK-NEXT:    }