Add 0-D vector support to vector.fma

* VectorOps.td: widen the lhs/rhs/acc operand and result constraints of
  vector.fma from AnyVector to AnyVectorOfAnyRank so that 0-D vectors such
  as vector<f32> are accepted.
* ConvertVectorToLLVM.cpp: the vector.fma -> llvm.intr.fmuladd lowering now
  accepts rank <= 1 (it previously required exactly rank 1).
* Tests: round-trip a 0-D vector.fma in ops.mlir, check that it lowers to
  llvm.intr.fmuladd on vector<1xf32> in vector-to-llvm.mlir, and execute it
  in the 0-D CPU integration test (4 * 4 + 4 = 20).

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -630,8 +630,10 @@
        NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
      ] # ElementwiseMappable.traits>,
-    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$lhs,
+                   AnyVectorOfAnyRank:$rhs,
+                   AnyVectorOfAnyRank:$acc)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "vector fused multiply-add";
   let description = [{
     Multiply-add expressions operate on n-D vectors and compute a fused
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -614,7 +614,7 @@ matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vType = fmaOp.getVectorType();
-    if (vType.getRank() != 1)
+    if (vType.getRank() > 1)
       return failure();
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                  adaptor.rhs(), adaptor.acc());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1065,7 +1065,7 @@
 
 // -----
 
-func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   // CHECK-SAME: %[[A:.*]]: vector<8xf32>
   // CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
@@ -1093,7 +1093,11 @@
   // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
   %2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
 
-  return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
+  // CHECK: %[[D0:.*]] = "llvm.intr.fmuladd"
+  // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+  %3 = vector.fma %d, %d, %d : vector<f32>
+
+  return %0, %1, %2, %3: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -476,11 +476,13 @@
 }
 
 // CHECK-LABEL: @vector_fma
-func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>, %c: vector<f32>) {
   // CHECK: vector.fma %{{.*}} : vector<8xf32>
   vector.fma %a, %a, %a : vector<8xf32>
   // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
   vector.fma %b, %b, %b : vector<8x4xf32>
+  // CHECK: vector.fma %{{.*}} : vector<f32>
+  vector.fma %c, %c, %c : vector<f32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,6 +105,14 @@
   return
 }
 
+func @fma_0d(%four: vector<f32>) {
+  %0 = vector.fma %four, %four, %four : vector<f32>
+  // 4 * 4 + 4 = 20
+  // CHECK: ( 20 )
+  vector.print %0: vector<f32>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -131,5 +139,8 @@
   %one_idx = arith.constant 1 : index
   call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
 
+  %5 = arith.constant dense<4.0> : vector<f32>
+  call @fma_0d(%5) : (vector<f32>) -> ()
+
   return
 }