diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h --- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h +++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h @@ -71,6 +71,7 @@ bool isFortranScalarNumericalType(mlir::Type); bool isFortranNumericalArrayObject(mlir::Type); +bool isFortranNumericalOrLogicalArrayObject(mlir::Type); bool isPassByRefOrIntegerType(mlir::Type); bool isI1Type(mlir::Type); // scalar i1 or logical, or sequence of logical (via (boxed?) array or expr) diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td --- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td @@ -113,6 +113,11 @@ def AnyFortranNumericalArrayObject : Type; +def IsFortranNumericalOrLogicalArrayObjectPred + : CPred<"::hlfir::isFortranNumericalOrLogicalArrayObject($_self)">; +def AnyFortranNumericalOrLogicalArrayObject : Type; + def IsPassByRefOrIntegerTypePred : CPred<"::hlfir::isPassByRefOrIntegerType($_self)">; def AnyPassByRefOrIntegerType : Type]> { + let summary = "MATMUL transformational intrinsic"; + let description = [{ + Matrix multiplication + }]; + + let arguments = (ins + AnyFortranNumericalOrLogicalArrayObject:$lhs, + AnyFortranNumericalOrLogicalArrayObject:$rhs, + DefaultValuedAttr:$fastmath + ); + + let results = (outs hlfir_ExprType); + + let assemblyFormat = [{ + $lhs $rhs attr-dict `:` functional-type(operands, results) + }]; + + let builders = [OpBuilder<(ins "mlir::Value":$lhs, + "mlir::Value":$rhs, + "mlir::Type":$resultType)>]; + + let hasVerifier = 1; +} + def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let summary = "Create a variable from an expression value"; diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -115,6 +115,18 @@ return false; } +bool hlfir::isFortranNumericalOrLogicalArrayObject(mlir::Type type) { + if (isBoxAddressType(type)) + return false; + if (auto arrayTy = + getFortranElementOrSequenceType(type).dyn_cast()) { + mlir::Type eleTy = arrayTy.getEleTy(); + return isFortranScalarNumericalType(eleTy) || + mlir::isa(eleTy); + } + return false; +} + bool hlfir::isPassByRefOrIntegerType(mlir::Type type) { mlir::Type unwrappedType = fir::unwrapPassByRefType(type); return fir::isa_integer(unwrappedType); diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -510,6 +510,76 @@ build(builder, result, resultType, array, dim, mask); } +//===----------------------------------------------------------------------===// +// MatmulOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult hlfir::MatmulOp::verify() { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); + fir::SequenceType lhsTy = + hlfir::getFortranElementOrSequenceType(lhs.getType()) + .cast(); + fir::SequenceType rhsTy = + hlfir::getFortranElementOrSequenceType(rhs.getType()) + .cast(); + llvm::ArrayRef lhsShape = lhsTy.getShape(); + llvm::ArrayRef rhsShape = rhsTy.getShape(); + std::size_t lhsRank = lhsShape.size(); + std::size_t rhsRank = rhsShape.size(); + mlir::Type lhsEleTy = lhsTy.getEleTy(); + mlir::Type rhsEleTy = rhsTy.getEleTy(); + hlfir::ExprType resultTy = getResult().getType().cast(); + llvm::ArrayRef resultShape = resultTy.getShape(); + mlir::Type resultEleTy = resultTy.getEleTy(); + + if (((lhsRank != 1) && (lhsRank != 2)) || ((rhsRank != 1) && (rhsRank != 2))) + return emitOpError("array must have either rank 1 or rank 2"); + + if ((lhsRank == 1) && (rhsRank == 1)) + return emitOpError("at least one array must have rank 2"); + + if (mlir::isa(lhsEleTy) != + mlir::isa(rhsEleTy)) + return emitOpError("if one array is logical, so should the other be"); + + int64_t lastLhsDim = lhsShape[lhsRank - 1]; + int64_t firstRhsDim = rhsShape[0]; + constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent(); + if (lastLhsDim != firstRhsDim) + if ((lastLhsDim != unknownExtent) && (firstRhsDim != unknownExtent)) + return emitOpError( + "the last dimension of LHS should match the first dimension of RHS"); + + if (mlir::isa(lhsEleTy) != + mlir::isa(resultEleTy)) + return emitOpError("the result type should be a logical only if the " + "argument types are logical"); + + llvm::SmallVector expectedResultShape; + if (lhsRank == 2) { + if (rhsRank == 2) { + expectedResultShape.push_back(lhsShape[0]); + expectedResultShape.push_back(rhsShape[1]); + } else { + // rhsRank == 1 + expectedResultShape.push_back(lhsShape[0]); + } + } else { + // lhsRank == 1 + // rhsRank == 2 + expectedResultShape.push_back(rhsShape[1]); + } + if (resultShape.size() != expectedResultShape.size()) + return emitOpError("incorrect result shape"); + if (resultShape[0] != expectedResultShape[0]) + return emitOpError("incorrect result shape"); + if (resultShape.size() == 2 && resultShape[1] != expectedResultShape[1]) + return emitOpError("incorrect result shape"); + + return mlir::success(); +} + //===----------------------------------------------------------------------===// // AssociateOp //===----------------------------------------------------------------------===// diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -319,3 +319,59 @@ // expected-error@+1 {{'hlfir.sum' op result rank must be one less than ARRAY}} %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr, i32, !fir.box>) -> !hlfir.expr } + +// ----- +func.func @bad_matmul1(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul' op array must have either rank 1 or rank 2}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmul2(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul' op at least one array must have rank 2}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmul3(%arg0: !hlfir.expr>, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul' op if one array is logical, so should the other be}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmul4(%arg0: !hlfir.expr, %arg1: !hlfir.expr<200x?xi32>) { + // expected-error@+1 {{'hlfir.matmul' op the last dimension of LHS should match the first dimension of RHS}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr<200x?xi32>) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmul5(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul' op the result type should be a logical only if the argument types are logical}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr> + return +} + +// ----- +func.func @bad_matmul6(%arg0: !hlfir.expr<1x2xi32>, %arg1: !hlfir.expr<2x3xi32>) { + // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<1x2xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32> + return +} + +// ----- +func.func @bad_matmul7(%arg0: !hlfir.expr<1x2xi32>, %arg1: !hlfir.expr<2xi32>) { + // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<1x2xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<1x3xi32> + return +} + +// ----- +func.func @bad_matmul8(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x3xi32>) { + // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}} + %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<1x3xi32> + return +} diff --git a/flang/test/HLFIR/matmul.fir b/flang/test/HLFIR/matmul.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/matmul.fir @@ -0,0 +1,99 @@ +// Test hlfir.matmul operation parse, verify (no errors), and unparse + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +// arguments are expressions of known shape +func.func @matmul0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32> + return +} +// CHECK-LABEL: func.func @matmul0 +// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x2xi32>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions of assumed shape +func.func @matmul1(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul1 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions where only some dimensions are known #1 +func.func @matmul2(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr) -> !hlfir.expr<2x2xi32> + return +} +// CHECK-LABEL: func.func @matmul2 +// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x?xi32>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr) -> !hlfir.expr<2x2xi32> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions where only some dimensions are known #2 +func.func @matmul3(%arg0: !hlfir.expr, %arg1: !hlfir.expr<2x?xi32>) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr<2x?xi32>) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul3 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr<2x?xi32>) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are logicals +func.func @matmul4(%arg0: !hlfir.expr>, %arg1: !hlfir.expr>) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr>) -> !hlfir.expr> + return +} +// CHECK-LABEL: func.func @matmul4 +// CHECK: %[[ARG0:.*]]: !hlfir.expr>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr>, !hlfir.expr>) -> !hlfir.expr> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// lhs is rank 1 +func.func @matmul5(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul5 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// rhs is rank 1 +func.func @matmul6(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul6 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are boxed arrays +func.func @matmul7(%arg0: !fir.box>, %arg1: !fir.box>) { + %res = hlfir.matmul %arg0 %arg1 : (!fir.box>, !fir.box>) -> !hlfir.expr<2x2xf32> + return +} +// CHECK-LABEL: func.func @matmul7 +// CHECK: %[[ARG0:.*]]: !fir.box>, +// CHECK: %[[ARG1:.*]]: !fir.box>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!fir.box>, !fir.box>) -> !hlfir.expr<2x2xf32> +// CHECK-NEXT: return +// CHECK-NEXT: }