diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
@@ -15,6 +15,7 @@
 #include "flang/Optimizer/Dialect/FortranVariableInterface.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -382,6 +382,9 @@
     $lhs $rhs attr-dict `:` functional-type(operands, results)
   }];
 
+  // MATMUL(TRANSPOSE(...), ...) => hlfir.matmul_transpose
+  let hasCanonicalizeMethod = 1;
+
   let hasVerifier = 1;
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
 #include
 #include
 
@@ -638,6 +639,57 @@
   return mlir::success();
 }
 
+/// Canonicalize MATMUL(TRANSPOSE(a), b) into hlfir.matmul_transpose a, b.
+///
+/// The hlfir.transpose is erased, so this only fires when the transposed
+/// value has no user other than \p matmulOp (as its lhs) and at most one
+/// hlfir.destroy. Otherwise failure is returned and the IR is left untouched.
+mlir::LogicalResult
+hlfir::MatmulOp::canonicalize(MatmulOp matmulOp,
+                              mlir::PatternRewriter &rewriter) {
+  mlir::Value lhs = matmulOp.getLhs();
+  auto transposeOp = lhs.getDefiningOp<hlfir::TransposeOp>();
+  if (!transposeOp)
+    return mlir::failure();
+
+  mlir::Value transposed = transposeOp.getResult();
+  // MATMUL(TRANSPOSE(a), TRANSPOSE(a)) sharing one hlfir.transpose: the new
+  // operation would still need the transposed value as its rhs.
+  if (matmulOp.getRhs() == transposed)
+    return mlir::failure();
+
+  // Counting the uses is not enough: a second user that is not the
+  // hlfir.destroy (e.g. an hlfir.assign of the transposed value) would be left
+  // referring to an erased operation.
+  hlfir::DestroyOp transposeDestroy;
+  for (mlir::Operation *user : transposed.getUsers()) {
+    if (user == matmulOp.getOperation())
+      continue;
+    auto destroyOp = mlir::dyn_cast<hlfir::DestroyOp>(user);
+    if (!destroyOp || transposeDestroy)
+      return mlir::failure();
+    transposeDestroy = destroyOp;
+  }
+
+  mlir::Location loc = matmulOp.getLoc();
+  mlir::Type resultTy = matmulOp.getResult().getType();
+  auto matmulTransposeOp = rewriter.create<hlfir::MatmulTransposeOp>(
+      loc, resultTy, transposeOp.getArray(), matmulOp.getRhs());
+
+  // we don't need to remove any hlfir.destroy because it will be needed for
+  // the new intrinsic result anyway
+  rewriter.replaceOp(matmulOp, matmulTransposeOp.getResult());
+
+  // but we do need to get rid of the hlfir.destroy for the hlfir.transpose
+  // result (which is entirely removed). It was looked up before rewriting so
+  // that no user is erased while the use list is being iterated.
+  if (transposeDestroy)
+    rewriter.eraseOp(transposeDestroy);
+  rewriter.eraseOp(transposeOp);
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/HLFIR/mul_transpose.f90 b/flang/test/HLFIR/mul_transpose.f90
--- a/flang/test/HLFIR/mul_transpose.f90
+++ b/flang/test/HLFIR/mul_transpose.f90
@@ -1,4 +1,5 @@
 ! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
 
@@ -22,6 +23,10 @@
 ! CHECK-BASE-NEXT: hlfir.destroy %[[MATMUL_RES]]
 ! CHECK-BASE-NEXT: hlfir.destroy %[[TRANSPOSE_RES]]
 
+! CHECK-CANONICAL-NEXT: %[[CHAIN_RES:.*]] = hlfir.matmul_transpose %[[A_DECL]]#0 %[[B_DECL]]#0 : (!fir.ref>, !fir.ref>) -> !hlfir.expr<1x2xf32>
+! CHECK-CANONICAL-NEXT: hlfir.assign %[[CHAIN_RES]] to %[[RES_DECL]]#0 : !hlfir.expr<1x2xf32>, !fir.ref>
+! CHECK-CANONICAL-NEXT: hlfir.destroy %[[CHAIN_RES]]
+
 ! CHECK-LOWERING: %[[A_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
 ! CHECK-LOWERING: %[[TRANSPOSE_CONV_RES:.*]] = fir.convert %[[TRANSPOSE_RES_BOX:.*]] : (!fir.ref>>>) -> !fir.ref>
 ! CHECK-LOWERING: %[[A_BOX_CONV:.*]] = fir.convert %[[A_BOX]] : (!fir.box>) -> !fir.box