diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -107,16 +107,16 @@ namespace { -/// Converts binary standard operations to SPIR-V operations. +/// Converts unary and binary standard operations to SPIR-V operations. template -class BinaryOpPattern final : public SPIRVOpLowering { +class UnaryAndBinaryOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 2); + assert(operands.size() <= 2); auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); @@ -572,21 +572,31 @@ SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, - BinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, + UnaryAndBinaryOpPattern, BitwiseOpPattern, BitwiseOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -31,9 +31,33 @@ return } -// Check float operation conversions. -// CHECK-LABEL: @float32_scalar -func @float32_scalar(%lhs: f32, %rhs: f32) { +// Check float unary operation conversions. +// CHECK-LABEL: @float32_unary_scalar +func @float32_unary_scalar(%arg0: f32) { + // CHECK: spv.GLSL.FAbs %{{.*}}: f32 + %0 = absf %arg0 : f32 + // CHECK: spv.GLSL.Ceil %{{.*}}: f32 + %1 = ceilf %arg0 : f32 + // CHECK: spv.GLSL.Cos %{{.*}}: f32 + %2 = cos %arg0 : f32 + // CHECK: spv.GLSL.Exp %{{.*}}: f32 + %3 = exp %arg0 : f32 + // CHECK: spv.GLSL.Log %{{.*}}: f32 + %4 = log %arg0 : f32 + // CHECK: spv.FNegate %{{.*}}: f32 + %5 = negf %arg0 : f32 + // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32 + %6 = rsqrt %arg0 : f32 + // CHECK: spv.GLSL.Sqrt %{{.*}}: f32 + %7 = sqrt %arg0 : f32 + // CHECK: spv.GLSL.Tanh %{{.*}}: f32 + %8 = tanh %arg0 : f32 + return +} + +// Check float binary operation conversions. +// CHECK-LABEL: @float32_binary_scalar +func @float32_binary_scalar(%lhs: f32, %rhs: f32) { // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 %0 = addf %lhs, %rhs: f32 // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32