Let the recursive n-D vector.fma rewrite pattern declare bounded rewrite
recursion via setHasBoundedRewriteRecursion(), and extend the vector_fma
conversion test with a rank-3 operand (vector<1x1x1xf32>).

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed onto one line; the layout below is a
best-effort reconstruction (hunk line counts re-verified). If a context
line fails to match, apply with `git apply --ignore-whitespace`.

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -752,6 +752,12 @@
 public:
   using OpRewritePattern::OpRewritePattern;
 
+  void initialize() {
+    // This pattern recursively unpacks one dimension at a time. The recursion
+    // is bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
   LogicalResult matchAndRewrite(FMAOp op,
                                 PatternRewriter &rewriter) const override {
     auto vType = op.getVectorType();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -941,10 +941,11 @@
 
 // -----
 
-func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
   // CHECK-LABEL: @vector_fma
   //  CHECK-SAME: %[[A:.*]]: vector<8xf32>
   //  CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
+  //  CHECK-SAME: %[[C:.*]]: vector<1x1x1xf32>
   //       CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
   //       CHECK: "llvm.intr.fmuladd"
   //  CHECK-SAME:   (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
@@ -964,7 +965,11 @@
   //       CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
   %1 = vector.fma %b, %b, %b : vector<2x4xf32>
 
-  return %0, %1: vector<8xf32>, vector<2x4xf32>
+  //       CHECK: %[[C0:.*]] = "llvm.intr.fmuladd"
+  //  CHECK-SAME:   (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+  %2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
+
+  return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
 }
 
 // -----