diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -972,4 +972,44 @@ }]; } +// ----- + +def SPV_GLSLFmaOp : SPV_GLSLTernaryArithmeticOp<"Fma", 50, SPV_Float> { + let summary = "Computes a * b + c."; + + let description = [{ + In uses where this operation is decorated with NoContraction: + + - fma is considered a single operation, whereas the expression a * b + c + is considered two operations. + - The precision of fma can differ from the precision of the expression + a * b + c. + - fma will be computed with the same precision as any other fma decorated + with NoContraction, giving invariant results for the same input values + of a, b, and c. + + Otherwise, in the absence of a NoContraction decoration, there are no + special constraints on the number of operations or difference in precision + between fma and the expression a * b + c. + + The operands must all be a scalar or vector whose component type is + floating-point. + + Result Type and the type of all operands must be the same type. Results + are computed per component. 
+ + + ``` + fma-op ::= ssa-id `=` `spv.GLSL.Fma` ssa-use, ssa-use, ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %0 = spv.GLSL.Fma %a, %b, %c : f32 + %1 = spv.GLSL.Fma %a, %b, %c : vector<3xf16> + ``` + }]; +} + #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -36,9 +36,8 @@ vector::BroadcastOp::Adaptor adaptor(operands); SmallVector source(broadcastOp.getVectorType().getNumElements(), adaptor.source()); - Value construct = rewriter.create( - broadcastOp.getLoc(), broadcastOp.getVectorType(), source); - rewriter.replaceOp(broadcastOp, construct); + rewriter.replaceOpWithNewOp( + broadcastOp, broadcastOp.getVectorType(), source); return success(); } }; @@ -55,9 +54,23 @@ return failure(); vector::ExtractOp::Adaptor adaptor(operands); int32_t id = extractOp.position().begin()->cast().getInt(); - Value newExtract = rewriter.create( - extractOp.getLoc(), adaptor.vector(), id); - rewriter.replaceOp(extractOp, newExtract); + rewriter.replaceOpWithNewOp( + extractOp, adaptor.vector(), id); + return success(); + } +}; + +struct VectorFmaOpConvert final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::FMAOp fmaOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) + return failure(); + vector::FMAOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); } }; @@ -74,9 +87,8 @@ return failure(); vector::InsertOp::Adaptor adaptor(operands); int32_t id = insertOp.position().begin()->cast().getInt(); - Value newInsert = rewriter.create( - insertOp.getLoc(), adaptor.source(), 
adaptor.dest(), id); - rewriter.replaceOp(insertOp, newInsert); + rewriter.replaceOpWithNewOp( + insertOp, adaptor.source(), adaptor.dest(), id); return success(); } }; @@ -92,10 +104,9 @@ if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); vector::ExtractElementOp::Adaptor adaptor(operands); - Value newExtractElement = rewriter.create( - extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(), + rewriter.replaceOpWithNewOp( + extractElementOp, extractElementOp.getType(), adaptor.vector(), extractElementOp.position()); - rewriter.replaceOp(extractElementOp, newExtractElement); return success(); } }; @@ -111,10 +122,9 @@ if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); vector::InsertElementOp::Adaptor adaptor(operands); - Value newInsertElement = rewriter.create( - insertElementOp.getLoc(), insertElementOp.getType(), - insertElementOp.dest(), adaptor.source(), insertElementOp.position()); - rewriter.replaceOp(insertElementOp, newInsertElement); + rewriter.replaceOpWithNewOp( + insertElementOp, insertElementOp.getType(), insertElementOp.dest(), + adaptor.source(), insertElementOp.position()); return success(); } }; @@ -124,7 +134,8 @@ void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert(typeConverter, context); + patterns.insert( + typeConverter, context); } diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -57,3 +57,13 @@ %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> spv.Return } + +// ----- + +// CHECK-LABEL: func @fma +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32> +func @fma(%a: 
vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) { + %0 = vector.fma %a, %b, %c: vector<4xf32> + spv.Return +} diff --git a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir @@ -345,3 +345,23 @@ %2 = spv.GLSL.SClamp %arg0, %min, %max : i32 return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.GLSL.Fma +//===----------------------------------------------------------------------===// + +func @fma(%a : f32, %b : f32, %c : f32) -> () { + // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 + %2 = spv.GLSL.Fma %a, %b, %c : f32 + return +} + +// ----- + +func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () { + // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32> + %2 = spv.GLSL.Fma %a, %b, %c : vector<3xf32> + return +} diff --git a/mlir/test/Target/SPIRV/glsl-ops.mlir b/mlir/test/Target/SPIRV/glsl-ops.mlir --- a/mlir/test/Target/SPIRV/glsl-ops.mlir +++ b/mlir/test/Target/SPIRV/glsl-ops.mlir @@ -48,4 +48,10 @@ %13 = spv.GLSL.SClamp %arg0, %arg1, %arg2 : si32 spv.Return } + + spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { + // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 + %13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32 + spv.Return + } }