diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1994,6 +1994,88 @@
   }
 };
 
+// `tanh(x)` is lowered to `1 - 2 / (exp(2 * x) + 1)`.
+//
+// This is algebraically equal to `(exp(2 * x) - 1) / (exp(2 * x) + 1)` but
+// stays well-defined when `exp(2 * x)` overflows to +inf for large positive
+// `x`: the quotient form would evaluate `inf / inf = NaN`, while this form
+// evaluates `1 - 2 / inf = 1`. For large negative `x`, `exp(2 * x)` goes to 0
+// and the result approaches `1 - 2 / 1 = -1`.
+//
+// NOTE: like the quotient form, this loses relative precision for `x` close
+// to 0 (cancellation in the final subtraction) and returns +0.0 for -0.0.
+struct TanhOpLowering : public ConvertOpToLLVMPattern<TanhOp> {
+  using ConvertOpToLLVMPattern<TanhOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<TanhOp> adaptor(operands);
+    auto operandType = adaptor.operand().getType().dyn_cast<LLVM::LLVMType>();
+    if (!operandType)
+      return failure();
+
+    auto loc = op->getLoc();
+    auto resultType = *op->result_type_begin();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
+
+    // Scalar or 1-D vector operand: emit the computation once.
+    if (!operandType.isArrayTy()) {
+      LLVM::ConstantOp one, two;
+      if (operandType.isVectorTy()) {
+        auto shapedType = resultType.cast<ShapedType>();
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType, SplatElementsAttr::get(shapedType, floatOne));
+        two = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType, SplatElementsAttr::get(shapedType, floatTwo));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+        two = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatTwo);
+      }
+      Value result = buildTanh(rewriter, loc, adaptor.operand(), one, two);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    // Otherwise the operand is an LLVM array, which is only handled when the
+    // original result is an (N-D) vector: the computation is then applied to
+    // each innermost 1-D vector via handleMultidimensionalVectors.
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return failure();
+
+    return handleMultidimensionalVectors(
+        op, operands, typeConverter,
+        [&](LLVM::LLVMType llvmVectorTy, ValueRange vectorOperands) {
+          // Splat constants shaped like a single innermost 1-D vector.
+          auto splatType = mlir::VectorType::get(
+              {cast<llvm::VectorType>(llvmVectorTy.getUnderlyingType())
+                   ->getNumElements()},
+              floatType);
+          Value one = rewriter.create<LLVM::ConstantOp>(
+              loc, llvmVectorTy, SplatElementsAttr::get(splatType, floatOne));
+          Value two = rewriter.create<LLVM::ConstantOp>(
+              loc, llvmVectorTy, SplatElementsAttr::get(splatType, floatTwo));
+          return buildTanh(rewriter, loc, vectorOperands[0], one, two);
+        },
+        rewriter);
+  }
+
+private:
+  // Emits element-wise IR for `1 - 2 / (exp(2 * x) + 1)`. `x`, `one` and
+  // `two` are expected to share one LLVM float (or 1-D float vector) type.
+  static Value buildTanh(ConversionPatternRewriter &rewriter, Location loc,
+                         Value x, Value one, Value two) {
+    Value twoX = rewriter.create<LLVM::FMulOp>(loc, x, two);
+    Value exp2x = rewriter.create<LLVM::ExpOp>(loc, twoX);
+    Value denominator = rewriter.create<LLVM::FAddOp>(loc, exp2x, one);
+    Value ratio = rewriter.create<LLVM::FDivOp>(loc, two, denominator);
+    return rewriter.create<LLVM::FSubOp>(loc, one, ratio);
+  }
+};
+
 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
   using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
 
@@ -2970,6 +3052,7 @@
       SignedRemIOpLowering,
       SignedShiftRightOpLowering,
       SinOpLowering,
+      TanhOpLowering,
       SplatOpLowering,
       SplatNdOpLowering,
       SqrtOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -78,6 +78,54 @@
 
 // -----
 
+// CHECK-LABEL: func @tanh(
+// CHECK-SAME: !llvm.float
+func @tanh(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[XX:.*]] = llvm.fmul %arg0, %[[TWO]] : !llvm.float
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[XX]]) : (!llvm.float) -> !llvm.float
+  // CHECK: %[[DENOM:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.float
+  // CHECK: %[[RATIO:.*]] = llvm.fdiv %[[TWO]], %[[DENOM]] : !llvm.float
+  // CHECK: %[[TANH:.*]] = llvm.fsub %[[ONE]], %[[RATIO]] : !llvm.float
+  %0 = tanh %arg0 : f32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_double(
+// CHECK-SAME: !llvm.double
+func @tanh_double(%arg0 : f64) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double
+  // CHECK: %[[XX:.*]] = llvm.fmul %arg0, %[[TWO]] : !llvm.double
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[XX]]) : (!llvm.double) -> !llvm.double
+  // CHECK: %[[DENOM:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm.double
+  // CHECK: %[[RATIO:.*]] = llvm.fdiv %[[TWO]], %[[DENOM]] : !llvm.double
+  // CHECK: %[[TANH:.*]] = llvm.fsub %[[ONE]], %[[RATIO]] : !llvm.double
+  %0 = tanh %arg0 : f64
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_vector(
+// CHECK-SAME: !llvm<"<4 x float>">
+func @tanh_vector(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(dense<2.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  // CHECK: %[[XX:.*]] = llvm.fmul %arg0, %[[TWO]] : !llvm<"<4 x float>">
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%[[XX]]) : (!llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  // CHECK: %[[DENOM:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : !llvm<"<4 x float>">
+  // CHECK: %[[RATIO:.*]] = llvm.fdiv %[[TWO]], %[[DENOM]] : !llvm<"<4 x float>">
+  // CHECK: %[[TANH:.*]] = llvm.fsub %[[ONE]], %[[RATIO]] : !llvm<"<4 x float>">
+  %0 = tanh %arg0 : vector<4xf32>
+  std.return
+}
+
+// -----
+
 // This should not crash. The first operation cannot be converted, so the
 // second should not match. This attempts to convert `return` to `llvm.return`
 // and complains about non-LLVM types.