diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2295,6 +2295,60 @@
   }
 };
 
+// An `expm1` is converted into `exp - 1`.
+struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
+  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::ExpM1Op::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+      return success();
+    }
+
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return failure();
+
+    return handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto exp =
+              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+        },
+        rewriter);
+  }
+};
+
 // A `rsqrt` is converted into `1 / sqrt`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -3774,6 +3828,7 @@ DivFOpLowering, ExpOpLowering, Exp2OpLowering, + ExpM1OpLowering, FloorFOpLowering, GenericAtomicRMWOpLowering, LogOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -12,6 +12,18 @@ // ----- +// CHECK-LABEL: func @expm1( +// CHECK-SAME: f32 +func @expm1(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32 + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32 + %0 = math.expm1 %arg0 : f32 + std.return +} + +// ----- + // CHECK-LABEL: func @rsqrt( // CHECK-SAME: f32 func @rsqrt(%arg0 : f32) {