diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -650,8 +650,10 @@
        NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface>
      ] # ElementwiseMappable.traits>,
-    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$lhs,
+                   AnyVectorOfAnyRank:$rhs,
+                   AnyVectorOfAnyRank:$acc)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "vector fused multiply-add";
   let description = [{
     Multiply-add expressions operate on n-D vectors and compute a fused
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -746,7 +746,7 @@
   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vType = fmaOp.getVectorType();
-    if (vType.getRank() != 1)
+    if (vType.getRank() > 1)
       return failure();
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1084,7 +1084,7 @@
 
 // -----
 
-func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
+func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   // CHECK-SAME: %[[A:.*]]: vector<8xf32>
   // CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
@@ -1112,7 +1112,11 @@
   // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
   %2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
 
-  return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
+  // CHECK: %[[D0:.*]] = "llvm.intr.fmuladd"
+  // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+  %3 = vector.fma %d, %d, %d : vector<f32>
+
+  return %0, %1, %2, %3: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -478,11 +478,13 @@
 }
 
 // CHECK-LABEL: @vector_fma
-func.func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+func.func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>, %c: vector<f32>) {
   // CHECK: vector.fma %{{.*}} : vector<8xf32>
   vector.fma %a, %a, %a : vector<8xf32>
   // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
   vector.fma %b, %b, %b : vector<8x4xf32>
+  // CHECK: vector.fma %{{.*}} : vector<f32>
+  vector.fma %c, %c, %c : vector<f32>
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -105,9 +105,19 @@
   return
 }
 
-func.func @reduce_add(%arg0: vector<f32>) -> f32 {
+func.func @reduce_add(%arg0: vector<f32>) {
   %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
-  return %0 : f32
+  vector.print %0 : f32
+  // CHECK: 5
+  return
+}
+
+func.func @fma_0d(%four: vector<f32>) {
+  %0 = vector.fma %four, %four, %four : vector<f32>
+  // 4 * 4 + 4 = 20
+  // CHECK: ( 20 )
+  vector.print %0: vector<f32>
+  return
 }
 
 func.func @entry() {
@@ -137,9 +147,10 @@
   call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
 
   %red_array = arith.constant dense<5.0> : vector<f32>
-  %red_res = call @reduce_add(%red_array) : (vector<f32>) -> (f32)
-  vector.print %red_res : f32
-  // CHECK: 5
+  call @reduce_add(%red_array) : (vector<f32>) -> ()
+
+  %5 = arith.constant dense<4.0> : vector<f32>
+  call @fma_0d(%5) : (vector<f32>) -> ()
 
   return
 }