diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2303,6 +2303,68 @@
   }
 };
 
+// A `log1p` is converted into `log(1 + ...)`.
+//
+// Matching fails when the converted operand type is missing or is not an
+// LLVM-compatible type, or when the array path sees a non-vector result.
+struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
+  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::Log1pOp::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    // Non-array operands: emit `log(1 + x)` directly, with a splat constant
+    // for `1` when the operand is a vector.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one =
+          LLVM::isCompatibleVectorType(operandType)
+              ? rewriter.create<LLVM::ConstantOp>(
+                    loc, operandType,
+                    SplatElementsAttr::get(resultType.cast<ShapedType>(),
+                                           floatOne))
+              : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+
+      auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
+                                               transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
+      return success();
+    }
+
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          // Add to the callback's own operand: `transformed.operand()` still
+          // has the full operand type rather than `llvm1DVectorTy`.
+          auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
+                                                   operands[0]);
+          return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
+        },
+        rewriter);
+  }
+};
+
 // A `rsqrt` is converted into `1 / sqrt`.
 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
@@ -3788,6 +3850,7 @@
       GenericAtomicRMWOpLowering,
       LogOpLowering,
       Log10OpLowering,
+      Log1pOpLowering,
       Log2OpLowering,
       FPExtLowering,
       FPToSILowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -12,6 +12,18 @@
 
 // -----
 
+// CHECK-LABEL: func @log1p(
+// CHECK-SAME: f32
+func @log1p(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32
+  // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32
+  %0 = math.log1p %arg0 : f32
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt(
 // CHECK-SAME: f32
 func @rsqrt(%arg0 : f32) {