diff --git a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
--- a/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
@@ -55,6 +55,13 @@
                mlir::Value matrixABox, mlir::Value matrixBBox,
                mlir::Value resultBox);
 
+/// Generate a call to the MatmulTranspose runtime routine. The box receiving
+/// the result comes first, followed by the two operand boxes, matching the
+/// definition in Transformational.cpp and its caller in IntrinsicCall.cpp.
+void genMatmulTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+                        mlir::Value resultBox, mlir::Value matrixABox,
+                        mlir::Value matrixBBox);
+
 void genPack(fir::FirOpBuilder &builder, mlir::Location loc,
              mlir::Value resultBox, mlir::Value arrayBox, mlir::Value maskBox,
              mlir::Value vectorBox);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -269,6 +269,8 @@
   template
   mlir::Value genMask(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef);
+  fir::ExtendedValue genMatmulTranspose(mlir::Type,
+                                        llvm::ArrayRef);
   fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genMerge(mlir::Type, llvm::ArrayRef);
@@ -679,6 +681,10 @@
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
      /*isElemental=*/false},
+    {"matmul_transpose",
+     &I::genMatmulTranspose,
+     {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+     /*isElemental=*/false},
     {"max", &I::genExtremum},
     {"maxloc",
      &I::genMaxloc,
@@ -4015,6 +4021,37 @@
   return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL");
 }
 
+// MATMUL_TRANSPOSE
+// Lowers the intrinsic to a call to the MatmulTranspose runtime routine; the
+// runtime allocates the result, which is read back from a temporary mutable
+// box.
+fir::ExtendedValue
+IntrinsicLibrary::genMatmulTranspose(mlir::Type resultType,
+                                     llvm::ArrayRef args) {
+  assert(args.size() == 2);
+
+  // Handle required matmul_transpose arguments
+  fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
+  mlir::Value matrixA = fir::getBase(matrixTmpA);
+  fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
+  mlir::Value matrixB = fir::getBase(matrixTmpB);
+  // The result is rank 1 as soon as either operand is rank 1, else rank 2.
+  unsigned resultRank =
+      (matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
+
+  // Create mutable fir.box to be passed to the runtime for the result.
+  mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
+  fir::MutableBoxValue resultMutableBox =
+      fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+  mlir::Value resultIrBox =
+      fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+  // Call runtime. The runtime is allocating the result.
+  fir::runtime::genMatmulTranspose(builder, loc, resultIrBox, matrixA, matrixB);
+  // Read result from mutable fir.box and add it to the list of temps to be
+  // finalized by the StatementContext.
+  return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL_TRANSPOSE");
+}
+
 // MERGE
 fir::ExtendedValue
 IntrinsicLibrary::genMerge(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
--- a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Runtime/matmul-transpose.h"
 #include "flang/Runtime/matmul.h"
 #include "flang/Runtime/transformational.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -351,6 +352,23 @@
   builder.create(loc, func, args);
 }
 
+/// Generate call to MatmulTranspose intrinsic runtime routine.
+void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
+                                      mlir::Location loc, mlir::Value resultBox,
+                                      mlir::Value matrixABox,
+                                      mlir::Value matrixBBox) {
+  auto func =
+      fir::runtime::getRuntimeFunc(loc, builder);
+  auto fTy = func.getFunctionType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, matrixABox,
+                                    matrixBBox, sourceFile, sourceLine);
+  builder.create(loc, func, args);
+}
+
 /// Generate call to Pack intrinsic runtime routine.
 void fir::runtime::genPack(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value resultBox, mlir::Value arrayBox,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -257,6 +257,42 @@
   }
 };
 
+/// Convert hlfir::MatmulTransposeOp into a call to the "matmul_transpose"
+/// intrinsic: both operands are lowered with that intrinsic's argument
+/// lowering rules and the call result is handed to processReturnValue.
+struct MatmulTransposeOpConversion
+    : public HlfirIntrinsicConversion {
+  using HlfirIntrinsicConversion<
+      hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
+                  mlir::PatternRewriter &rewriter) const override {
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    mlir::Location loc = multranspose->getLoc();
+
+    mlir::Value lhs = multranspose.getLhs();
+    mlir::Value rhs = multranspose.getRhs();
+    llvm::SmallVector inArgs;
+    inArgs.push_back({lhs, lhs.getType()});
+    inArgs.push_back({rhs, rhs.getType()});
+
+    auto *argLowering = fir::getIntrinsicArgumentLowering("matmul_transpose");
+    llvm::SmallVector args =
+        lowerArguments(multranspose, inArgs, rewriter, argLowering);
+
+    mlir::Type scalarResultType =
+        hlfir::getFortranElementType(multranspose.getType());
+
+    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+        builder, loc, "matmul_transpose", scalarResultType, args);
+
+    processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
+};
+
 class LowerHLFIRIntrinsics
     : public hlfir::impl::LowerHLFIRIntrinsicsBase {
 public:
@@ -271,13 +307,14 @@
     mlir::ModuleOp module = this->getOperation();
     mlir::MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert(
-        context);
+    patterns.insert(context);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect();
-    target.addIllegalOp();
+    target.addIllegalOp();
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(
diff --git a/flang/test/HLFIR/mul_transpose.f90 b/flang/test/HLFIR/mul_transpose.f90
--- a/flang/test/HLFIR/mul_transpose.f90
+++ b/flang/test/HLFIR/mul_transpose.f90
@@ -1,6 +1,7 @@
 ! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING-OPT --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
 
 ! Test passing a hlfir.expr from one intrinsic to another
@@ -56,6 +57,20 @@
 ! CHECK-LOWERING-NEXT: hlfir.destroy %[[MUL_EXPR]]
 ! CHECK-LOWERING-NEXT: hlfir.destroy %[[TRANSPOSE_EXPR]]
 
+! CHECK-LOWERING-OPT: %[[LHS_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT: %[[B_BOX:.*]] = fir.embox %[[B_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT: %[[MUL_CONV_RES:.*]] = fir.convert %[[MUL_RES_BOX:.*]] : (!fir.ref>>>) -> !fir.ref>
+! CHECK-LOWERING-OPT: %[[LHS_CONV:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box>) -> !fir.box
+! CHECK-LOWERING-OPT: %[[B_BOX_CONV:.*]] = fir.convert %[[B_BOX]] : (!fir.box>) -> !fir.box
+! CHECK-LOWERING-OPT: fir.call @_FortranAMatmulTranspose(%[[MUL_CONV_RES]], %[[LHS_CONV]], %[[B_BOX_CONV]], %[[LOC_STR2:.*]], %[[LOC_N2:.*]])
+! CHECK-LOWERING-OPT: %[[MUL_RES_LD:.*]] = fir.load %[[MUL_RES_BOX:.*]]
+! CHECK-LOWERING-OPT: %[[MUL_RES_ADDR:.*]] = fir.box_addr %[[MUL_RES_LD]]
+! CHECK-LOWERING-OPT: %[[MUL_RES_VAR:.*]]:2 = hlfir.declare %[[MUL_RES_ADDR]]({{.*}}) {uniq_name = ".tmp.intrinsic_result"}
+! CHECK-LOWERING-OPT: %[[TRUE2:.*]] = arith.constant true
+! CHECK-LOWERING-OPT: %[[MUL_EXPR:.*]] = hlfir.as_expr %[[MUL_RES_VAR]]#0 move %[[TRUE2]] : (!fir.box>, i1) -> !hlfir.expr
+! CHECK-LOWERING-OPT: hlfir.assign %[[MUL_EXPR]] to %[[RES_DECL]]#0 : !hlfir.expr, !fir.ref>
+! CHECK-LOWERING-OPT: hlfir.destroy %[[MUL_EXPR]]
+
 ! [argument handling unchanged]
 ! CHECK-BUFFERING: fir.call @_FortranATranspose(
 ! CHECK-BUFFERING: %[[TRANSPOSE_RES_LD:.*]] = fir.load %[[TRANSPOSE_RES_BOX:.*]]