// Patch: mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
//
// Purpose: factor the per-op pattern registration done in
// mlir::populateMathToLibmConversionPatterns into a helper,
// populatePatternsForOp. Hunk 1 adds the helper just before the
// `} // namespace` line, after a class whose members include
// `std::string floatFunc, doubleFunc;`.
//
// populatePatternsForOp(patterns, ctx, floatFunc, doubleFunc) registers, for
// the single op type supplied as its template argument:
//   - a VecOpToScalarOp pattern and a PromoteOpToF32 pattern, both
//     constructed from `ctx` only; and
//   - one further pattern constructed from (ctx, floatFunc, doubleFunc).
//     Its class name is not visible in this copy of the patch.
// None of the helper's function parameters (RewritePatternSet &,
// MLIRContext *, StringRef, StringRef) depends on the template parameter.
// It therefore cannot be deduced, and every call site must name the op type
// explicitly.
//
// Hunk 2 removes two hand-maintained pattern lists (one of VecOpToScalarOp
// patterns, one of PromoteOpToF32 patterns) and 15 separate registrations
// taking (ctx, floatName, doubleName). It replaces them with 15 helper calls,
// one per float/double name pair:
//   atan2f/atan2, atanf/atan, cbrtf/cbrt, ceilf/ceil, cosf/cos, erff/erf,
//   expm1f/expm1, floorf/floor, log1pf/log1p, roundevenf/roundeven,
//   roundf/round, sinf/sin, tanf/tan, tanhf/tanh, truncf/trunc.
// After the patch, every op registered with a name pair also gets a
// VecOpToScalarOp and a PromoteOpToF32 pattern. This no longer relies on the
// three separate lists being kept in sync.
// NOTE(review): each removed list spans 6 lines of what look like 2 entries
// apiece (about 12 ops), versus 15 name pairs. If so, this patch also adds
// those two pattern kinds for 3 ops that previously lacked them. Confirm
// against the pre-patch file.
//
// NOTE(review): this copy of the patch appears damaged:
//   - Every `<...>` span seems to have been stripped. Examples: `template`
//     with no parameter list, `patterns.add, PromoteOpToF32>(ctx);`, and
//     `populatePatternsForOp(patterns, ctx, ...)` with no op type. The added
//     code as shown would not compile.
//   - The line breaks have been collapsed, so it will not apply with
//     `git apply` / `patch` as-is.
// Regenerate the patch from the source tree rather than hand-repairing it.
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -60,6 +60,14 @@ private: std::string floatFunc, doubleFunc; }; + +template +void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx, + StringRef floatFunc, StringRef doubleFunc) { + patterns.add, PromoteOpToF32>(ctx); + patterns.add>(ctx, floatFunc, doubleFunc); +} + } // namespace template @@ -153,35 +161,23 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); - patterns.add, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp>( - ctx); - patterns.add, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32>(ctx); - patterns.add>(ctx, "atanf", "atan"); - patterns.add>(ctx, "atan2f", "atan2"); - patterns.add>(ctx, "cbrtf", "cbrt"); - patterns.add>(ctx, "erff", "erf"); - patterns.add>(ctx, "expm1f", "expm1"); - patterns.add>(ctx, "tanf", "tan"); - patterns.add>(ctx, "tanhf", "tanh"); - patterns.add>(ctx, "roundevenf", - "roundeven"); - patterns.add>(ctx, "roundf", "round"); - patterns.add>(ctx, "cosf", "cos"); - patterns.add>(ctx, "sinf", "sin"); - patterns.add>(ctx, "log1pf", "log1p"); - patterns.add>(ctx, "floorf", "floor"); - patterns.add>(ctx, "ceilf", "ceil"); - patterns.add>(ctx, "truncf", "trunc"); + + populatePatternsForOp(patterns, ctx, "atan2f", "atan2"); + populatePatternsForOp(patterns, ctx, "atanf", "atan"); + populatePatternsForOp(patterns, ctx, "cbrtf", "cbrt"); + populatePatternsForOp(patterns, ctx, "ceilf", "ceil"); + populatePatternsForOp(patterns, ctx, "cosf", "cos"); + 
populatePatternsForOp(patterns, ctx, "erff", "erf"); + populatePatternsForOp(patterns, ctx, "expm1f", "expm1"); + populatePatternsForOp(patterns, ctx, "floorf", "floor"); + populatePatternsForOp(patterns, ctx, "log1pf", "log1p"); + populatePatternsForOp(patterns, ctx, "roundevenf", + "roundeven"); + populatePatternsForOp(patterns, ctx, "roundf", "round"); + populatePatternsForOp(patterns, ctx, "sinf", "sin"); + populatePatternsForOp(patterns, ctx, "tanf", "tan"); + populatePatternsForOp(patterns, ctx, "tanhf", "tanh"); + populatePatternsForOp(patterns, ctx, "truncf", "trunc"); } namespace {