diff --git a/flang/docs/HighLevelFIR.md b/flang/docs/HighLevelFIR.md --- a/flang/docs/HighLevelFIR.md +++ b/flang/docs/HighLevelFIR.md @@ -652,7 +652,6 @@ %element = hlfir.apply %array_expr %i, %j: (hlfir.expr) -> i32 ``` - #### Introducing operations for transformational intrinsic functions Motivation: Represent transformational intrinsics functions at a high-level so @@ -701,6 +700,39 @@ - selected_char_kind, selected_int_kind, selected_real_kind that returns scalar integers +#### Introducing operations for composed intrinsic functions + +Motivation: optimize commonly composed intrinsic functions (e.g. +MATMUL(TRANSPOSE(a), b)). Classic Flang already implements this optimization. + +An operation and runtime function will be added for each commonly used +composition of intrinsic functions. The operation will be the canonical way to +write this chained operation (MLIR canonicalization will rewrite the chain of +operations for the individual intrinsics into this single operation). + +These new operations will be treated as though they were standard +transformational intrinsic functions. + +The composed intrinsic operation will return an hlfir.expr. The arguments +may be hlfir.expr, boxed arrays, simple scalar types (e.g. i32, f32), or +variables. + +To keep things simple, these operations will only match one form of the composed +intrinsic functions; therefore, they will have no optional arguments. + +Syntax: +``` +%res = hlfir."intrinsic_name" %expr_or_var, ... +``` + +The composed intrinsic operation will be lowered to a `fir.call` to the newly +added runtime implementation of the operation. + +These operations should not be added where the only improvement is to avoid +creating a temporary intermediate buffer that would otherwise be removed by +intelligent bufferization of an hlfir.expr. Similarly, these should not replace +profitable uses of hlfir.elemental.
+
 #### Introducing operations for character operations and elemental intrinsic functions
 
 
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -402,6 +402,29 @@
   let hasVerifier = 1;
 }
 
+def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
+    [DeclareOpInterfaceMethods]> {
+  let summary = "Optimized MATMUL(TRANSPOSE(...), ...)";
+  let description = [{
+    Matrix multiplication where the left hand side is transposed
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods]> {
   let summary = "Create a variable from an expression value";
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -668,6 +668,77 @@
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulTransposeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies hlfir.matmul_transpose, i.e. MATMUL(TRANSPOSE(lhs), rhs).
+/// TRANSPOSE swaps the two dimensions of lhs, so it is the first dimension
+/// of lhs (not the last) that must agree with the first dimension of rhs,
+/// and the expected result shape is [lhs dim 2, rhs dim 2] for a rank-2 rhs
+/// or [lhs dim 2] for a rank-1 rhs.
+mlir::LogicalResult hlfir::MatmulTransposeOp::verify() {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(lhs.getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(rhs.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  std::size_t lhsRank = lhsShape.size();
+  std::size_t rhsRank = rhsShape.size();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
+  llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+  mlir::Type resultEleTy = resultTy.getEleTy();
+
+  // lhs must have rank 2 for the transpose to be valid
+  if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2)))
+    return emitOpError("array must have either rank 1 or rank 2");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+
+  // for matmul we compare the last dimension of lhs with the first dimension of
+  // rhs, but for MatmulTranspose, dimensions of lhs are inverted by the
+  // transpose
+  int64_t firstLhsDim = lhsShape[0];
+  int64_t firstRhsDim = rhsShape[0];
+  constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  if (firstLhsDim != firstRhsDim)
+    if ((firstLhsDim != unknownExtent) && (firstRhsDim != unknownExtent))
+      return emitOpError(
+          "the first dimension of LHS should match the first dimension of RHS");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(resultEleTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+
+  // TRANSPOSE(lhs) contributes lhs dimension 2; a rank-2 rhs also
+  // contributes its dimension 2.
+  llvm::SmallVector<int64_t, 2> expectedResultShape;
+  expectedResultShape.push_back(lhsShape[1]);
+  if (rhsRank == 2)
+    expectedResultShape.push_back(rhsShape[1]);
+
+  if (resultShape.size() != expectedResultShape.size())
+    return emitOpError("incorrect result shape");
+  // An extent that is unknown in the arguments may be known in the result
+  // type, so only extents known from the arguments have to match.
+  for (std::size_t i = 0; i < resultShape.size(); ++i)
+    if ((expectedResultShape[i] != unknownExtent) &&
+        (resultShape[i] != expectedResultShape[i]))
+      return emitOpError("incorrect result shape");
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
//===----------------------------------------------------------------------===// diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -397,6 +397,48 @@ return } +// ----- +func.func @bad_matmultranspose1(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmultranspose2(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmultranspose3(%arg0: !hlfir.expr>, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op if one array is logical, so should the other be}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmultranspose5(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op the result type should be a logical only if the argument types are logical}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr> + return +} + +// ----- +func.func @bad_matmultranspose6(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2x3xi32>) { + // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32> + return +} + +// ----- +func.func @bad_matmultranspose7(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2xi32>) { + // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2xi32>) -> 
!hlfir.expr<1x3xi32> + return +} + // ----- func.func @bad_assign_1(%arg0: !fir.box>, %arg1: !fir.box>) { // expected-error@+1 {{'hlfir.assign' op lhs must be an allocatable when `realloc` is set}} diff --git a/flang/test/HLFIR/matmul_transpose.fir b/flang/test/HLFIR/matmul_transpose.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/matmul_transpose.fir @@ -0,0 +1,87 @@ +// Test hlfir.matmul_transpose operation parse, verify (no errors), and unparse + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +// arguments are expressions of known shape +func.func @matmul_transpose0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32> + return +} +// CHECK-LABEL: func.func @matmul_transpose0 +// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x2xi32>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions of assumed shape +func.func @matmul_transpose1(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul_transpose1 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions where only some dimensions are known #1 +func.func @matmul_transpose2(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr<2x2xi32> + return +} +// CHECK-LABEL: func.func @matmul_transpose2 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: 
%[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr<2x2xi32> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions where only some dimensions are known #2 +func.func @matmul_transpose3(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr<2x?xi32>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul_transpose3 +// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x?xi32>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are logicals +func.func @matmul_transpose4(%arg0: !hlfir.expr>, %arg1: !hlfir.expr>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr>) -> !hlfir.expr> + return +} +// CHECK-LABEL: func.func @matmul_transpose4 +// CHECK: %[[ARG0:.*]]: !hlfir.expr>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr>, !hlfir.expr>) -> !hlfir.expr> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// rhs is rank 1 +func.func @matmul_transpose6(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul_transpose6 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are boxed arrays +func.func @matmul_transpose7(%arg0: !fir.box>, %arg1: !fir.box>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!fir.box>, !fir.box>) -> !hlfir.expr<2x2xf32> + return +} +// CHECK-LABEL: func.func @matmul_transpose7 +// CHECK: %[[ARG0:.*]]: 
!fir.box>, +// CHECK: %[[ARG1:.*]]: !fir.box>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!fir.box>, !fir.box>) -> !hlfir.expr<2x2xf32> +// CHECK-NEXT: return +// CHECK-NEXT: }