// Patch summary (reviewer note placed ahead of the `diff --git` header, outside
// every hunk, so hunk line counts are unaffected).
//
// Changes to BufferizeHLFIR.cpp shown below:
//   * Adds an include of "flang/Optimizer/HLFIR/HLFIRDialect.h" and one more
//     include whose target is not legible in this text.
//   * Removes the `args.size() == 3` assertion from the helper that takes
//     (args, rewriter, argLowering). MatmulOpConversion below passes two
//     arguments to `lowerArguments`; presumably that is this same helper, which
//     is why the assertion has to go - confirm against the full file.
//   * Adds MatmulOpConversion, a conversion pattern for hlfir::MatmulOp. Its
//     matchAndRewrite:
//       1. creates a fir::FirOpBuilder over the rewriter and installs an
//          HLFIRListener on it;
//       2. collects the op's lhs and rhs values, each paired with its type;
//       3. fetches the argument lowering rules registered under "matmul" and
//          passes them, with the collected arguments, to lowerArguments;
//       4. obtains the result's scalar type with
//          hlfir::getFortranElementType(matmul.getType());
//       5. calls fir::genIntrinsicCall(builder, loc, "matmul", ...), which
//          yields a result value and a must-be-freed flag;
//       6. forwards both to processReturnValue and returns mlir::success().
//     The `adaptor` parameter is not referenced in the body.
//   * Registers MatmulOpConversion in the pattern list built in
//     BufferizeHLFIR::runOnOperation.
//
// The trailing `+` lines belong to a FIR test (its file header is not visible
// here). The CHECK lines expect, in order: a result box initialised from
// fir.zero_bits with a zero-extent shape; a fir.call to @_FortranAMatmul taking
// the converted result box, lhs box, rhs box and two further operands; a load
// of the result box re-declared through hlfir.declare with uniq_name
// ".tmp.intrinsic_result"; an hlfir.assign of that temporary to the `res`
// variable; and a fir.freemem of the temporary before the return.
//
// NOTE(review): this text appears to have lost everything that was written
// between angle brackets (an `#include` with no target, `llvm::SmallVector` and
// `HlfirIntrinsicConversion` with no template arguments, `target.addIllegalOp`
// running straight into the test's function signature), and its line breaks
// have been collapsed. It cannot be applied as a patch in this form; recover
// the original revision before using it.
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -22,6 +22,7 @@ #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "flang/Optimizer/Support/FIRContext.h" @@ -29,6 +30,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" +#include #include namespace hlfir { @@ -518,7 +520,6 @@ const llvm::ArrayRef &args, mlir::ConversionPatternRewriter &rewriter, const fir::IntrinsicArgumentLoweringRules *argLowering) const { - assert(args.size() == 3 && "Transformational intrinsics have 3 args"); mlir::Location loc = op->getLoc(); fir::KindMapping kindMapping{rewriter.getContext()}; fir::FirOpBuilder builder{rewriter, kindMapping}; @@ -648,6 +649,39 @@ } }; +struct MatmulOpConversion : public HlfirIntrinsicConversion { + using HlfirIntrinsicConversion::HlfirIntrinsicConversion; + + mlir::LogicalResult + matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + fir::KindMapping kindMapping{rewriter.getContext()}; + fir::FirOpBuilder builder{rewriter, kindMapping}; + const mlir::Location &loc = matmul->getLoc(); + HLFIRListener listener{builder, rewriter}; + builder.setListener(&listener); + + mlir::Value lhs = matmul.getLhs(); + mlir::Value rhs = matmul.getRhs(); + llvm::SmallVector inArgs; + inArgs.push_back({lhs, lhs.getType()}); + inArgs.push_back({rhs, rhs.getType()}); + + auto *argLowering = fir::getIntrinsicArgumentLowering("matmul"); + llvm::SmallVector args = + lowerArguments(matmul, inArgs, rewriter, argLowering); + + 
mlir::Type scalarResultType = + hlfir::getFortranElementType(matmul.getType()); + + auto [resultExv, mustBeFreed] = + fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args); + + processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter); + return mlir::success(); + } +}; + class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase { public: void runOnOperation() override { @@ -661,12 +695,11 @@ auto module = this->getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns - .insert( - context); + patterns.insert< + ApplyOpConversion, AsExprOpConversion, AssignOpConversion, + AssociateOpConversion, ConcatOpConversion, DestroyOpConversion, + ElementalOpConversion, EndAssociateOpConversion, MatmulOpConversion, + NoReassocOpConversion, SetLengthOpConversion, SumOpConversion>(context); mlir::ConversionTarget target(*context); target.addIllegalOp> {fir.bindc_name = "lhs"}, %arg1: !fir.box> {fir.bindc_name = "rhs"}, %arg2: !fir.box> {fir.bindc_name = "res"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFmatmul1Elhs"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFmatmul1Eres"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFmatmul1Erhs"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %3 = hlfir.matmul %0#0 %2#0 {fastmath = #arith.fastmath} : (!fir.box>, !fir.box>) -> !hlfir.expr + hlfir.assign %3 to %1#0 : !hlfir.expr, !fir.box> + hlfir.destroy %3 : !hlfir.expr + return +} +// CHECK-LABEL: func.func @_QPmatmul1( +// CHECK: %[[ARG0:.*]]: !fir.box> {fir.bindc_name = "lhs"} +// CHECK: %[[ARG1:.*]]: !fir.box> {fir.bindc_name = "rhs"} +// CHECK: %[[ARG2:.*]]: !fir.box> {fir.bindc_name = "res"} +// CHECK-DAG: %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]] +// CHECK-DAG: %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]] + +// CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>> +// 
CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2> +// CHECK-DAG: %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]]) +// CHECK-DAG: fir.store %[[RET_EMBOX]] to %[[RET_BOX]] + +// CHECK: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] : (!fir.ref>>>) -> !fir.ref> +// CHECK-DAG: %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box>) -> !fir.box +// CHECK-DAG: %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box>) -> !fir.box +// CHECK: %[[NONE:.*]] = fir.call @_FortranAMatmul(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) + +// CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] +// CHECK-DAG: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] +// CHECK-DAG: %[[ADDR:.*]] = fir.box_addr %[[RET]] +// CHECK-NEXT: %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// TODO: fix alias analysis in hlfir.assign bufferization +// CHECK-NEXT: %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"} +// CHECK: %[[TUPLE0:.*]] = fir.undefined tuple>, i1> +// CHECK: %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[TRUE:.*]], [1 : index] +// CHECK: %[[TUPLE2:.*]] = fir.insert_value %[[TUPLE1]], %[[TMP]]#0, [0 : index] +// CHECK: hlfir.assign %[[TMP]]#0 to %[[RES_VAR]]#0 +// CHECK: fir.freemem %[[TMP]]#1 +// CHECK-NEXT: return +// CHECK-NEXT: }