diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -790,25 +790,25 @@ patterns.add< ConstantCompositeOpPattern, ConstantScalarOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, RemSIOpGLSLPattern, RemSIOpOCLPattern, BitwiseOpPattern, BitwiseOpPattern, XOrIOpLogicalPattern, XOrIOpBooleanPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, TypeCastingOpPattern, ExtUII1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -64,35 +64,36 @@ RewritePatternSet &patterns) { // GLSL patterns - patterns.add< - Log1pOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern>( - typeConverter, patterns.getContext()); + patterns + .add, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern>( + typeConverter, patterns.getContext()); // OpenCL patterns patterns.add, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern>( + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -15,16 +15,17 @@ namespace mlir { namespace spirv { -/// Converts unary and binary standard operations to SPIR-V operations. +/// Converts elementwise unary, binary and ternary standard operations to SPIR-V +/// operations. template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { +class ElementwiseOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); + assert(adaptor.getOperands().size() <= 3); auto dstType = this->getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -230,12 +230,12 @@ patterns.add< // Unary and binary patterns - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, context); diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -68,4 +68,19 @@ return } + // CHECK-LABEL: @float32_ternary_scalar +func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) { + // CHECK: spv.GLSL.Fma %{{.*}}: f32 + %0 = math.fma %a, %b, %c : f32 + return +} + +// CHECK-LABEL: @float32_ternary_vector +func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>, + %c: vector<4xf32>) { + // CHECK: spv.GLSL.Fma %{{.*}}: vector<4xf32> + %0 = math.fma %a, %b, %c : vector<4xf32> + return +} + } // end module