diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -583,8 +583,10 @@ } def Vector_FMAOp : - Op]>, + Op, + DeclareOpInterfaceMethods + ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>, Results<(outs AnyVector:$result)> { let summary = "vector fused multiply-add"; diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1258,6 +1258,14 @@ AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } +//===----------------------------------------------------------------------===// +// FmaOp +//===----------------------------------------------------------------------===// + +Optional> FMAOp::getShapeForUnroll() { + return llvm::to_vector<4>(getVectorType().getShape()); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -2456,8 +2464,7 @@ } Optional> TransferReadOp::getShapeForUnroll() { - auto s = getVectorType().getShape(); - return SmallVector{s.begin(), s.end()}; + return llvm::to_vector<4>(getVectorType().getShape()); } void TransferReadOp::getEffects( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -73,3 +73,10 @@ // CHECK: vector.contract { // CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> // CHECK: return + +func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.fma %a, %b, %c: vector<4x4xf32> + return %0 : vector<4x4xf32> +} +// CHECK-LABEL: func @vector_fma +// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32> diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -151,8 +151,9 @@ patterns.insert( ctx, UnrollVectorOptions() .setNativeShape(ArrayRef{2, 2}) - .setFilterConstraint( - [](Operation *op) { return success(isa(op)); })); + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn =