diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -317,6 +317,28 @@
   }
 };
 
+/// Converts math.log1p to SPIR-V ops.
+///
+/// SPIR-V has no direct operation for log(1+x), so this pattern lowers it
+/// explicitly to spv.FAdd (1 + x) followed by spv.GLSL.Log.
+class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
+public:
+  using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1);
+    Location loc = operation.getLoc();
+    auto type =
+        this->getTypeConverter()->convertType(operation.operand().getType());
+    auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
+    auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
+    rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
+    return success();
+  }
+};
+
 /// Converts std.remi_signed to SPIR-V ops.
/// /// This cannot be merged into the template unary/binary pattern due to @@ -1347,7 +1369,7 @@ UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, + Log1pOpPattern, SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -53,6 +53,10 @@ %3 = math.exp %arg0 : f32 // CHECK: spv.GLSL.Log %{{.*}}: f32 %4 = math.log %arg0 : f32 + // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 + // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} + // CHECK: spv.GLSL.Log %[[ADDONE]] + %40 = math.log1p %arg0 : f32 // CHECK: spv.FNegate %{{.*}}: f32 %5 = negf %arg0 : f32 // CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32