diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -91,6 +91,23 @@ loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); } +/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. +/// `srcType` must be a float or vector-of-float type (a vector yields a splat +/// of `value`); `dstType` is the LLVM type of the created constant. +static Value createFPConstant(Location loc, Type srcType, Type dstType, + PatternRewriter &rewriter, double value) { + if (auto vecType = srcType.dyn_cast()) { + auto floatType = vecType.getElementType().cast(); + return rewriter.create( + loc, dstType, + SplatElementsAttr::get(vecType, + rewriter.getFloatAttr(floatType, value))); + } + auto floatType = srcType.cast(); + return rewriter.create( + loc, dstType, rewriter.getFloatAttr(floatType, value)); +} + /// Utility function for bitfiled ops: /// - `BitFieldInsert` /// - `BitFieldSExtract` @@ -590,6 +607,29 @@ } }; +/// Converts `spv.GLSL.InverseSqrt` to `1 / sqrt(x)`. +class InverseSqrtPattern + : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto srcType = op.getType(); + auto dstType = typeConverter.convertType(srcType); + if (!dstType) + return failure(); + + Location loc = op.getLoc(); + Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); + // Use the remapped (type-converted) operand rather than op.operand(). + Value sqrt = rewriter.create(loc, dstType, operands[0]); + rewriter.replaceOpWithNewOp(op, dstType, one, sqrt); + return success(); + } +}; + /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
template class LoadStorePattern : public SPIRVToLLVMConversion { @@ -821,6 +861,42 @@ } }; +/// Convert `spv.Tanh` to +/// +/// 1 - 2 / (exp(2x) + 1) +/// +/// This equals (exp(2x) - 1) / (exp(2x) + 1) algebraically, but stays finite +/// when exp(2x) overflows to infinity for large x, where that form would +/// compute inf / inf = NaN instead of 1. +/// +class TanhPattern : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto srcType = tanhOp.getType(); + auto dstType = typeConverter.convertType(srcType); + if (!dstType) + return failure(); + + Location loc = tanhOp.getLoc(); + Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); + // Use the remapped (type-converted) operand rather than tanhOp.operand(). + Value multiplied = + rewriter.create(loc, dstType, two, operands[0]); + Value exponential = rewriter.create(loc, dstType, multiplied); + Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); + Value denominator = + rewriter.create(loc, dstType, exponential, one); + Value quotient = + rewriter.create(loc, dstType, two, denominator); + rewriter.replaceOpWithNewOp(tanhOp, dstType, one, quotient); + return success(); + } +}; + class VariablePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; @@ -1052,7 +1128,8 @@ DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, - DirectConversionPattern, TanPattern, + DirectConversionPattern, + InverseSqrtPattern, TanPattern, TanhPattern, // Logical ops DirectConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/glsl-ops-to-llvm.mlir @@ -103,3 +103,33 @@ %0 = spv.GLSL.Tan %arg0 : f32 return } + +//===----------------------------------------------------------------------===// +// spv.GLSL.Tanh +//===----------------------------------------------------------------------===// +
+// CHECK-LABEL: @tanh +func @tanh(%arg0: f32) { + // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float + // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : !llvm.float + // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[X2]]) : (!llvm.float) -> !llvm.float + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float + // CHECK: %[[DEN:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.float + // CHECK: %[[Q:.*]] = llvm.fdiv %[[TWO]], %[[DEN]] : !llvm.float + // CHECK: llvm.fsub %[[ONE]], %[[Q]] : !llvm.float + %0 = spv.GLSL.Tanh %arg0 : f32 + return +} + +//===----------------------------------------------------------------------===// +// spv.GLSL.InverseSqrt +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @inverse_sqrt +func @inverse_sqrt(%arg0: f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float + // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%{{.*}}) : (!llvm.float) -> !llvm.float + // CHECK: llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float + %0 = spv.GLSL.InverseSqrt %arg0 : f32 + return +}