diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h --- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h +++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRV.h @@ -18,10 +18,17 @@ namespace mlir { class SPIRVTypeConverter; +/// Type of SPIRV ops generated from math ops. +enum class MathToSPIRVConversionType { + GLSL, + OpenCL, +}; + /// Appends to a pattern list additional patterns for translating Math ops /// to SPIR-V ops. void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + MathToSPIRVConversionType type); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h --- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h @@ -17,8 +17,11 @@ namespace mlir { -/// Creates a pass to convert Math ops to SPIR-V ops. -std::unique_ptr> createConvertMathToSPIRVPass(); +/// Creates a pass to convert Math ops to GLSL SPIR-V ops. +std::unique_ptr> createConvertMathToGlslSPIRVPass(); + +/// Creates a pass to convert Math ops to OpenCL SPIR-V ops. +std::unique_ptr> createConvertMathToOpenclSPIRVPass(); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -313,9 +313,15 @@ // MathToSPIRV //===----------------------------------------------------------------------===// -def ConvertMathToSPIRV : Pass<"convert-math-to-spirv", "ModuleOp"> { - let summary = "Convert Math dialect to SPIR-V dialect"; - let constructor = "mlir::createConvertMathToSPIRVPass()"; +def ConvertMathToGlslSPIRV : Pass<"convert-math-to-glsl-spirv", "ModuleOp"> { + let summary = "Convert Math dialect to GLSL SPIR-V ops"; + let constructor = "mlir::createConvertMathToGlslSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; +} + +def ConvertMathToOpenclSPIRV : Pass<"convert-math-to-opencl-spirv", "ModuleOp"> { + let summary = "Convert Math dialect to OpenCL SPIR-V ops"; + let constructor = "mlir::createConvertMathToOpenclSPIRVPass()"; let dependentDialects = ["spirv::SPIRVDialect"]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -78,6 +78,69 @@ // ----- +def SPV_OCLTanhOp : SPV_OCLUnaryArithmeticOp<"tanh", 63, SPV_Float> { + let summary = "Compute hyperbolic tangent of x radians."; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.tanh %0 : f32 + %3 = spv.OCL.tanh %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_OCLCeilOp : SPV_OCLUnaryArithmeticOp<"ceil", 12, SPV_Float> { + let summary = [{ + Round x to integral value using the round to positive infinity rounding + mode. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + ceil-op ::= ssa-id `=` `spv.OCL.ceil` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.ceil %0 : f32 + %3 = spv.OCL.ceil %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> { let summary = "Compute the cosine of x radians."; @@ -93,7 +156,7 @@ ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:` + cos-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:` float-scalar-vector-type ```mlir @@ -168,6 +231,39 @@ // ----- +def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> { + let summary = [{ + Round x to the integral value using the round to negative infinity + rounding mode. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + floor-op ::= ssa-id `=` `spv.OCL.floor` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.floor %0 : f32 + %3 = spv.OCL.ceifloorl %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> { let summary = "Compute the natural logarithm of x."; @@ -183,7 +279,7 @@ ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:` + log-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:` float-scalar-vector-type ```mlir @@ -198,6 +294,67 @@ // ----- +def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> { + let summary = "Compute x to the power y."; + + let description = [{ + Result Type, x and y must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + restricted-float-scalar-type ::= `f16` | `f32` + restricted-float-scalar-vector-type ::= + restricted-float-scalar-type | + `vector<` integer-literal `x` restricted-float-scalar-type `>` + pow-op ::= ssa-id `=` `spv.OCL.pow` ssa-use `:` + restricted-float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.OCL.pow %0, %1 : f32 + %3 = spv.OCL.pow %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> { + let summary = "Compute inverse square root of x."; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + rsqrt-op ::= ssa-id `=` `spv.OCL.rsqrt` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.rsqrt %0 : f32 + %3 = spv.OCL.rsqrt %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> { let summary = "Compute sine of x radians."; @@ -213,7 +370,7 @@ ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:` + sin-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:` float-scalar-vector-type ```mlir @@ -243,7 +400,7 @@ ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:` + sqrt-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:` float-scalar-vector-type ```mlir diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" @@ -34,6 +35,7 @@ /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to /// these operations. +template class Log1pOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -48,7 +50,7 @@ auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); auto onePlus = rewriter.create(loc, one, adaptor.getOperands()[0]); - rewriter.replaceOpWithNewOp(operation, type, onePlus); + rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } }; @@ -60,21 +62,41 @@ namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add< - Log1pOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern>( - typeConverter, patterns.getContext()); + RewritePatternSet &patterns, + MathToSPIRVConversionType type) { + if (type == MathToSPIRVConversionType::GLSL) { + patterns.add< + Log1pOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern>( + typeConverter, patterns.getContext()); + } else if (type == MathToSPIRVConversionType::OpenCL) { + patterns + .add, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern>( + typeConverter, patterns.getContext()); + } else { + llvm_unreachable("Invalid MathToSPIRVConversionType"); + } } } // namespace mlir diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp @@ -19,14 +19,38 @@ using namespace mlir; namespace { -/// A pass converting MLIR Math operations into the SPIR-V dialect. -class ConvertMathToSPIRVPass - : public ConvertMathToSPIRVBase { +/// A pass converting MLIR Math operations into the GLSL SPIR-V ops. +class ConvertMathToGlslSPIRVPass + : public ConvertMathToGlslSPIRVBase { + void runOnOperation() override; +}; + +/// A pass converting MLIR Math operations into the OpenCL SPIR-V ops. +class ConvertMathToOpenclSPIRVPass + : public ConvertMathToOpenclSPIRVBase { void runOnOperation() override; }; } // namespace -void ConvertMathToSPIRVPass::runOnOperation() { +void ConvertMathToGlslSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); + + RewritePatternSet patterns(context); + populateMathToSPIRVPatterns(typeConverter, patterns, + MathToSPIRVConversionType::GLSL); + + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + return signalPassFailure(); +} + +void ConvertMathToOpenclSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -37,12 +61,19 @@ SPIRVTypeConverter typeConverter(targetAttr); RewritePatternSet patterns(context); - populateMathToSPIRVPatterns(typeConverter, patterns); + populateMathToSPIRVPatterns(typeConverter, patterns, + MathToSPIRVConversionType::OpenCL); if (failed(applyPartialConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } -std::unique_ptr> mlir::createConvertMathToSPIRVPass() { - return std::make_unique(); +std::unique_ptr> +mlir::createConvertMathToGlslSPIRVPass() { + return std::make_unique(); +} + +std::unique_ptr> +mlir::createConvertMathToOpenclSPIRVPass() { + return std::make_unique(); } diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp @@ -43,7 +43,8 @@ // TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV RewritePatternSet patterns(context); arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); - populateMathToSPIRVPatterns(typeConverter, patterns); + populateMathToSPIRVPatterns(typeConverter, patterns, + MathToSPIRVConversionType::GLSL); populateStandardToSPIRVPatterns(typeConverter, patterns); populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, patterns); diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir rename from mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir rename to mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-math-to-glsl-spirv -verify-diagnostics %s -o - | FileCheck %s // CHECK-LABEL: @float32_unary_scalar func @float32_unary_scalar(%arg0: f32) { diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir rename from mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir rename to mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -1,61 +1,61 @@ -// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-math-to-opencl-spirv -verify-diagnostics %s -o - | FileCheck %s // CHECK-LABEL: @float32_unary_scalar func @float32_unary_scalar(%arg0: f32) { - // CHECK: spv.GLSL.Cos %{{.*}}: f32 + // CHECK: spv.OCL.cos %{{.*}}: f32 %0 = math.cos %arg0 : f32 - // CHECK: spv.GLSL.Exp %{{.*}}: f32 + // CHECK: spv.OCL.exp %{{.*}}: f32 %1 = math.exp %arg0 : f32 - // CHECK: spv.GLSL.Log %{{.*}}: f32 + // CHECK: spv.OCL.log %{{.*}}: f32 %2 = math.log %arg0 : f32 // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} - // CHECK: spv.GLSL.Log %[[ADDONE]] + // CHECK: spv.OCL.log %[[ADDONE]] %3 = math.log1p %arg0 : f32 - // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32 + // CHECK: spv.OCL.rsqrt %{{.*}}: f32 %4 = math.rsqrt %arg0 : f32 - // CHECK: spv.GLSL.Sqrt %{{.*}}: f32 + // CHECK: spv.OCL.sqrt %{{.*}}: f32 %5 = math.sqrt %arg0 : f32 - // CHECK: spv.GLSL.Tanh %{{.*}}: f32 + // CHECK: spv.OCL.tanh %{{.*}}: f32 %6 = math.tanh %arg0 : f32 - // CHECK: spv.GLSL.Sin %{{.*}}: f32 + // CHECK: spv.OCL.sin %{{.*}}: f32 %7 = math.sin %arg0 : f32 return } // CHECK-LABEL: @float32_unary_vector func @float32_unary_vector(%arg0: vector<3xf32>) { - // CHECK: spv.GLSL.Cos %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.cos %{{.*}}: vector<3xf32> %0 = math.cos %arg0 : vector<3xf32> - // CHECK: spv.GLSL.Exp %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.exp %{{.*}}: vector<3xf32> %1 = math.exp %arg0 : vector<3xf32> - // CHECK: spv.GLSL.Log %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.log %{{.*}}: vector<3xf32> %2 = math.log %arg0 : vector<3xf32> // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} - // CHECK: spv.GLSL.Log %[[ADDONE]] + // CHECK: spv.OCL.log %[[ADDONE]] %3 = math.log1p %arg0 : vector<3xf32> - // CHECK: spv.GLSL.InverseSqrt %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32> %4 = math.rsqrt %arg0 : vector<3xf32> - // CHECK: spv.GLSL.Sqrt %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32> %5 = math.sqrt %arg0 : vector<3xf32> - // CHECK: spv.GLSL.Tanh %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32> %6 = math.tanh %arg0 : vector<3xf32> - // CHECK: spv.GLSL.Sin %{{.*}}: vector<3xf32> + // CHECK: spv.OCL.sin %{{.*}}: vector<3xf32> %7 = math.sin %arg0 : vector<3xf32> return } // CHECK-LABEL: @float32_binary_scalar func @float32_binary_scalar(%lhs: f32, %rhs: f32) { - // CHECK: spv.GLSL.Pow %{{.*}}: f32 + // CHECK: spv.OCL.pow %{{.*}}: f32 %0 = math.powf %lhs, %rhs : f32 return } // CHECK-LABEL: @float32_binary_vector func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) { - // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32> + // CHECK: spv.OCL.pow %{{.*}}: vector<4xf32> %0 = math.powf %lhs, %rhs : vector<4xf32> return } diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -14,6 +14,14 @@ %4 = spv.OCL.log %arg0 : f32 // CHECK: {{%.*}} = spv.OCL.sqrt {{%.*}} : f32 %5 = spv.OCL.sqrt %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.ceil {{%.*}} : f32 + %6 = spv.OCL.ceil %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.floor {{%.*}} : f32 + %7 = spv.OCL.floor %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.pow {{%.*}}, {{%.*}} : f32 + %8 = spv.OCL.pow %arg0, %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.rsqrt {{%.*}} : f32 + %9 = spv.OCL.rsqrt %arg0 : f32 spv.Return }