diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -1,4 +1,4 @@ -//===- LowerHLFIRIntrinsics.cpp - Bufferize HLFIR ------------------------===// +//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -37,7 +37,22 @@ /// runtime calls template class HlfirIntrinsicConversion : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; +public: + explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx) + : mlir::OpRewritePattern{ctx} { + // required for cases where intrinsics are chained together e.g. + // matmul(matmul(a, b), c) + // because converting the inner operation then invalidates the + // outer operation: causing the pattern to apply recursively. + // + // This is safe because we always progress with each iteration. Circular + // applications of operations are not expressible in MLIR because we use + // an SSA form and one must become first. E.g. + // %a = hlfir.matmul %b %d + // %b = hlfir.matmul %a %d + // cannot be written. 
+ this->setHasBoundedRewriteRecursion(true); // base is dependent: use this-> (OpConversionPattern is not a base, MSVC rejects it) + } protected: struct IntrinsicArgument { diff --git a/flang/test/HLFIR/matmul-lowering.fir b/flang/test/HLFIR/matmul-lowering.fir --- a/flang/test/HLFIR/matmul-lowering.fir +++ b/flang/test/HLFIR/matmul-lowering.fir @@ -43,3 +43,39 @@ // CHECK: hlfir.destroy %[[ASEXPR]] // CHECK-NEXT: return // CHECK-NEXT: } + +// nested matmuls leading to recursive pattern application +func.func @_QPtest(%arg0: !fir.ref> {fir.bindc_name = "a"}, %arg1: !fir.ref> {fir.bindc_name = "b"}, %arg2: !fir.ref> {fir.bindc_name = "c"}, %arg3: !fir.ref> {fir.bindc_name = "out"}) { + %c3 = arith.constant 3 : index + %c3_0 = arith.constant 3 : index + %0 = fir.shape %c3, %c3_0 : (index, index) -> !fir.shape<2> + %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtestEa"} : (!fir.ref>, !fir.shape<2>) -> (!fir.ref>, !fir.ref>) + %c3_1 = arith.constant 3 : index + %c3_2 = arith.constant 3 : index + %2 = fir.shape %c3_1, %c3_2 : (index, index) -> !fir.shape<2> + %3:2 = hlfir.declare %arg1(%2) {uniq_name = "_QFtestEb"} : (!fir.ref>, !fir.shape<2>) -> (!fir.ref>, !fir.ref>) + %c3_3 = arith.constant 3 : index + %c3_4 = arith.constant 3 : index + %4 = fir.shape %c3_3, %c3_4 : (index, index) -> !fir.shape<2> + %5:2 = hlfir.declare %arg2(%4) {uniq_name = "_QFtestEc"} : (!fir.ref>, !fir.shape<2>) -> (!fir.ref>, !fir.ref>) + %c3_5 = arith.constant 3 : index + %c3_6 = arith.constant 3 : index + %6 = fir.shape %c3_5, %c3_6 : (index, index) -> !fir.shape<2> + %7:2 = hlfir.declare %arg3(%6) {uniq_name = "_QFtestEout"} : (!fir.ref>, !fir.shape<2>) -> (!fir.ref>, !fir.ref>) + %8 = hlfir.matmul %1#0 %3#0 {fastmath = #arith.fastmath} : (!fir.ref>, !fir.ref>) -> !hlfir.expr<3x3xf32> + %9 = hlfir.matmul %8 %5#0 {fastmath = #arith.fastmath} : (!hlfir.expr<3x3xf32>, !fir.ref>) -> !hlfir.expr<3x3xf32> + hlfir.assign %9 to %7#0 : !hlfir.expr<3x3xf32>, !fir.ref> + hlfir.destroy %9 : !hlfir.expr<3x3xf32> + hlfir.destroy %8 : !hlfir.expr<3x3xf32> + return +} 
+// Just check that the patterns are applied successfully (both matmuls become runtime calls); the details are checked above. +// CHECK-LABEL: func.func @_QPtest( +// CHECK-SAME: %arg0: !fir.ref> {fir.bindc_name = "a"}, +// CHECK-SAME: %arg1: !fir.ref> {fir.bindc_name = "b"}, +// CHECK-SAME: %arg2: !fir.ref> {fir.bindc_name = "c"}, +// CHECK-SAME: %arg3: !fir.ref> {fir.bindc_name = "out"}) { +// CHECK: fir.call @_FortranAMatmul( +// CHECK: fir.call @_FortranAMatmul( +// CHECK: return +// CHECK-NEXT: }