diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -30,6 +30,29 @@ // normal RewritePattern. namespace { +/// Converts math.expm1 to SPIR-V ops. +/// +/// SPIR-V has no direct operation for exp(x)-1, so this emits the target's exp +/// op (the template argument) and subtracts the constant 1 from it with FSub. +template +class ExpM1OpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() == 1); + Location loc = operation.getLoc(); + auto type = // NOTE(review): not null-checked; return failure() on null? + this->getTypeConverter()->convertType(operation.getOperand().getType()); + auto exp = rewriter.create(loc, type, adaptor.getOperands()[0]); + auto one = spirv::ConstantOp::getOne(type, loc, rewriter); + rewriter.replaceOpWithNewOp(operation, exp, one); + return success(); + } +}; + /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x).
Explicitly lower to @@ -65,7 +88,7 @@ // GLSL patterns patterns - .add, + .add, ExpM1OpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, @@ -81,7 +104,7 @@ typeConverter, patterns.getContext()); // OpenCL patterns - patterns.add, + patterns.add, ExpM1OpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -8,26 +8,30 @@ %0 = math.cos %arg0 : f32 // CHECK: spv.GLSL.Exp %{{.*}}: f32 %1 = math.exp %arg0 : f32 + // CHECK: %[[EXP:.+]] = spv.GLSL.Exp %arg0 + // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 + // CHECK: spv.FSub %[[EXP]], %[[ONE]] + %2 = math.expm1 %arg0 : f32 // CHECK: spv.GLSL.Log %{{.*}}: f32 - %2 = math.log %arg0 : f32 + %3 = math.log %arg0 : f32 // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} // CHECK: spv.GLSL.Log %[[ADDONE]] - %3 = math.log1p %arg0 : f32 + %4 = math.log1p %arg0 : f32 // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32 - %4 = math.rsqrt %arg0 : f32 + %5 = math.rsqrt %arg0 : f32 // CHECK: spv.GLSL.Sqrt %{{.*}}: f32 - %5 = math.sqrt %arg0 : f32 + %6 = math.sqrt %arg0 : f32 // CHECK: spv.GLSL.Tanh %{{.*}}: f32 - %6 = math.tanh %arg0 : f32 + %7 = math.tanh %arg0 : f32 // CHECK: spv.GLSL.Sin %{{.*}}: f32 - %7 = math.sin %arg0 : f32 + %8 = math.sin %arg0 : f32 // CHECK: spv.GLSL.FAbs %{{.*}}: f32 - %8 = math.abs %arg0 : f32 + %9 = math.abs %arg0 : f32 // CHECK: spv.GLSL.Ceil %{{.*}}: f32 - %9 = math.ceil %arg0 : f32 + %10 = math.ceil %arg0 : f32 // CHECK: spv.GLSL.Floor %{{.*}}: f32 - %10 = math.floor %arg0 : f32 + %11 = math.floor %arg0 : f32 return } @@ -37,20 +41,24 @@ %0 = math.cos %arg0 : vector<3xf32> // CHECK: spv.GLSL.Exp %{{.*}}: 
vector<3xf32> %1 = math.exp %arg0 : vector<3xf32> + // CHECK: %[[EXP:.+]] = spv.GLSL.Exp %arg0 + // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> + // CHECK: spv.FSub %[[EXP]], %[[ONE]] + %2 = math.expm1 %arg0 : vector<3xf32> // CHECK: spv.GLSL.Log %{{.*}}: vector<3xf32> - %2 = math.log %arg0 : vector<3xf32> + %3 = math.log %arg0 : vector<3xf32> // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} // CHECK: spv.GLSL.Log %[[ADDONE]] - %3 = math.log1p %arg0 : vector<3xf32> + %4 = math.log1p %arg0 : vector<3xf32> // CHECK: spv.GLSL.InverseSqrt %{{.*}}: vector<3xf32> - %4 = math.rsqrt %arg0 : vector<3xf32> + %5 = math.rsqrt %arg0 : vector<3xf32> // CHECK: spv.GLSL.Sqrt %{{.*}}: vector<3xf32> - %5 = math.sqrt %arg0 : vector<3xf32> + %6 = math.sqrt %arg0 : vector<3xf32> // CHECK: spv.GLSL.Tanh %{{.*}}: vector<3xf32> - %6 = math.tanh %arg0 : vector<3xf32> + %7 = math.tanh %arg0 : vector<3xf32> // CHECK: spv.GLSL.Sin %{{.*}}: vector<3xf32> - %7 = math.sin %arg0 : vector<3xf32> + %8 = math.sin %arg0 : vector<3xf32> return } diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -8,28 +8,32 @@ %0 = math.cos %arg0 : f32 // CHECK: spv.OCL.exp %{{.*}}: f32 %1 = math.exp %arg0 : f32 + // CHECK: %[[EXP:.+]] = spv.OCL.exp %arg0 + // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 + // CHECK: spv.FSub %[[EXP]], %[[ONE]] + %2 = math.expm1 %arg0 : f32 // CHECK: spv.OCL.log %{{.*}}: f32 - %2 = math.log %arg0 : f32 + %3 = math.log %arg0 : f32 // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} // CHECK: spv.OCL.log %[[ADDONE]] - %3 = math.log1p %arg0 : f32 + %4 = math.log1p %arg0 : f32 // CHECK: spv.OCL.rsqrt %{{.*}}: 
f32 - %4 = math.rsqrt %arg0 : f32 + %5 = math.rsqrt %arg0 : f32 // CHECK: spv.OCL.sqrt %{{.*}}: f32 - %5 = math.sqrt %arg0 : f32 + %6 = math.sqrt %arg0 : f32 // CHECK: spv.OCL.tanh %{{.*}}: f32 - %6 = math.tanh %arg0 : f32 + %7 = math.tanh %arg0 : f32 // CHECK: spv.OCL.sin %{{.*}}: f32 - %7 = math.sin %arg0 : f32 + %8 = math.sin %arg0 : f32 // CHECK: spv.OCL.fabs %{{.*}}: f32 - %8 = math.abs %arg0 : f32 + %9 = math.abs %arg0 : f32 // CHECK: spv.OCL.ceil %{{.*}}: f32 - %9 = math.ceil %arg0 : f32 + %10 = math.ceil %arg0 : f32 // CHECK: spv.OCL.floor %{{.*}}: f32 - %10 = math.floor %arg0 : f32 + %11 = math.floor %arg0 : f32 // CHECK: spv.OCL.erf %{{.*}}: f32 - %11 = math.erf %arg0 : f32 + %12 = math.erf %arg0 : f32 return } @@ -39,20 +43,24 @@ %0 = math.cos %arg0 : vector<3xf32> // CHECK: spv.OCL.exp %{{.*}}: vector<3xf32> %1 = math.exp %arg0 : vector<3xf32> + // CHECK: %[[EXP:.+]] = spv.OCL.exp %arg0 + // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> + // CHECK: spv.FSub %[[EXP]], %[[ONE]] + %2 = math.expm1 %arg0 : vector<3xf32> // CHECK: spv.OCL.log %{{.*}}: vector<3xf32> - %2 = math.log %arg0 : vector<3xf32> + %3 = math.log %arg0 : vector<3xf32> // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} // CHECK: spv.OCL.log %[[ADDONE]] - %3 = math.log1p %arg0 : vector<3xf32> + %4 = math.log1p %arg0 : vector<3xf32> // CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32> - %4 = math.rsqrt %arg0 : vector<3xf32> + %5 = math.rsqrt %arg0 : vector<3xf32> // CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32> - %5 = math.sqrt %arg0 : vector<3xf32> + %6 = math.sqrt %arg0 : vector<3xf32> // CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32> - %6 = math.tanh %arg0 : vector<3xf32> + %7 = math.tanh %arg0 : vector<3xf32> // CHECK: spv.OCL.sin %{{.*}}: vector<3xf32> - %7 = math.sin %arg0 : vector<3xf32> + %8 = math.sin %arg0 : vector<3xf32> return }