diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h --- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h +++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h @@ -21,7 +21,8 @@ #include "mlir/Conversion/Passes.h.inc" void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + bool approximateLog1p = true); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h --- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h +++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h @@ -8,8 +8,7 @@ #ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ #define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ -#include "mlir/Transforms/DialectConversion.h" -#include +#include "mlir/IR/PatternMatch.h" namespace mlir { template @@ -20,9 +19,7 @@ /// Populate the given list with patterns that convert from Math to Libm calls. /// If log1pBenefit is present, use it instead of benefit for the Log1p op. -void populateMathToLibmConversionPatterns( - RewritePatternSet &patterns, PatternBenefit benefit, - std::optional log1pBenefit = std::nullopt); +void populateMathToLibmConversionPatterns(RewritePatternSet &patterns); /// Create a pass to convert Math operations to libm calls. std::unique_ptr> createConvertMathToLibmPass(); diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -561,10 +561,11 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> { let summary = "Convert Math dialect to LLVM dialect"; - let description = [{ - This pass converts supported Math ops to LLVM dialect intrinsics. - }]; let dependentDialects = ["LLVM::LLVMDialect"]; + let options = [ + Option<"approximateLog1p", "approximate-log1p", "bool", "true", + "Enable approximation of Log1p."> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -291,7 +291,7 @@ void runOnOperation() override { RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); - populateMathToLLVMConversionPatterns(converter, patterns); + populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -301,7 +301,10 @@ } // namespace void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { + RewritePatternSet &patterns, + bool approximateLog1p) { + if (approximateLog1p) + patterns.add(converter); // clang-format off patterns.add< AbsFOpLowering, @@ -319,7 +322,6 @@ FloorOpLowering, FmaOpLowering, Log10OpLowering, - Log1pOpLowering, Log2OpLowering, LogOpLowering, PowFOpLowering, diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -14,11 +14,10 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include +#include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOLIBM @@ -52,8 +51,8 @@ public: using OpRewritePattern::OpRewritePattern; ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc, PatternBenefit benefit) - : OpRewritePattern(context, benefit), floatFunc(floatFunc), + StringRef doubleFunc) + : OpRewritePattern(context), floatFunc(floatFunc), doubleFunc(doubleFunc){}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; @@ -152,53 +151,37 @@ return success(); } -void mlir::populateMathToLibmConversionPatterns( - RewritePatternSet &patterns, PatternBenefit benefit, - std::optional log1pBenefit) { +void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); patterns.add, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp>( - patterns.getContext(), benefit); + ctx); patterns.add, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32>( - patterns.getContext(), benefit); - patterns.add>(patterns.getContext(), "atanf", - "atan", benefit); - patterns.add>(patterns.getContext(), - "atan2f", "atan2", benefit); - patterns.add>(patterns.getContext(), "cbrtf", - "cbrt", benefit); - patterns.add>(patterns.getContext(), "erff", - "erf", benefit); - patterns.add>(patterns.getContext(), - "expm1f", "expm1", benefit); - patterns.add>(patterns.getContext(), "tanf", - "tan", benefit); - patterns.add>(patterns.getContext(), "tanhf", - "tanh", benefit); - patterns.add>( - patterns.getContext(), "roundevenf", "roundeven", benefit); - patterns.add>(patterns.getContext(), - "roundf", "round", benefit); - patterns.add>(patterns.getContext(), "cosf", - "cos", benefit); - patterns.add>(patterns.getContext(), "sinf", - "sin", benefit); - patterns.add>( - patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit)); - patterns.add>(patterns.getContext(), - "floorf", "floor", benefit); - patterns.add>(patterns.getContext(), "ceilf", - "ceil", benefit); - patterns.add>(patterns.getContext(), - "truncf", "trunc", benefit); + PromoteOpToF32, PromoteOpToF32>(ctx); + patterns.add>(ctx, "atanf", "atan"); + patterns.add>(ctx, "atan2f", "atan2"); + patterns.add>(ctx, "cbrtf", "cbrt"); + patterns.add>(ctx, "erff", "erf"); + patterns.add>(ctx, "expm1f", "expm1"); + patterns.add>(ctx, "tanf", "tan"); + patterns.add>(ctx, "tanhf", "tanh"); + patterns.add>(ctx, "roundevenf", + "roundeven"); + patterns.add>(ctx, "roundf", "round"); + patterns.add>(ctx, "cosf", "cos"); + patterns.add>(ctx, "sinf", "sin"); + patterns.add>(ctx, "log1pf", "log1p"); + patterns.add>(ctx, "floorf", "floor"); + patterns.add>(ctx, "ceilf", "ceil"); + patterns.add>(ctx, "truncf", "trunc"); } namespace { @@ -212,7 +195,7 @@ auto module = getOperation(); RewritePatternSet patterns(&getContext()); - populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); + populateMathToLibmConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect