diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -139,6 +139,7 @@ } }; +template struct VectorFmaOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -148,8 +149,8 @@ Type dstType = getTypeConverter()->convertType(fmaOp.getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp( - fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); + rewriter.replaceOpWithNewOp(fmaOp, dstType, adaptor.getLhs(), + adaptor.getRhs(), adaptor.getAcc()); return success(); } }; @@ -380,9 +381,10 @@ RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + VectorExtractStridedSliceOpConvert, + VectorFmaOpConvert, + VectorFmaOpConvert, VectorInsertElementOpConvert, + VectorInsertOpConvert, VectorReductionPattern, + VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, + VectorSplatPattern>(typeConverter, patterns.getContext()); } diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -16,6 +16,27 @@ // ----- +module attributes { spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> } { + +// CHECK-LABEL: @cl_fma +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: spv.CL.fma %[[A]], %[[B]], %[[C]] : vector<4xf32> +func.func @cl_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> { + %0 = vector.fma %a, %b, %c: vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @cl_fma_size1_vector +// CHECK: spv.CL.fma %{{.+}} : f32 +func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> { + %0 = vector.fma %a, %b, %c: vector<1xf32> + return %0 : vector<1xf32> +} + +} // end module + +// ----- + // CHECK-LABEL: @broadcast // CHECK-SAME: %[[A:.*]]: f32 // CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]