diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -60,6 +60,16 @@
 private:
   std::string floatFunc, doubleFunc;
 };
+
+/// Adds the VecOpToScalarOp, PromoteOpToF32 and ScalarOpToLibmCall patterns
+/// for `OpTy`; `floatFunc` / `doubleFunc` are forwarded to ScalarOpToLibmCall.
+template <typename OpTy>
+void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
+                           StringRef floatFunc, StringRef doubleFunc) {
+  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
+  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+}
+
 } // namespace
 
 template <typename Op>
@@ -153,35 +163,23 @@
 
 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
-  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
-               VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
-               VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
-               VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
-               VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
-               VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
-      ctx);
-  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
-               PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
-               PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
-               PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
-               PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
-               PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(ctx);
-  patterns.add<ScalarOpToLibmCall<math::AtanOp>>(ctx, "atanf", "atan");
-  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(ctx, "atan2f", "atan2");
-  patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(ctx, "cbrtf", "cbrt");
-  patterns.add<ScalarOpToLibmCall<math::ErfOp>>(ctx, "erff", "erf");
-  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(ctx, "expm1f", "expm1");
-  patterns.add<ScalarOpToLibmCall<math::TanOp>>(ctx, "tanf", "tan");
-  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(ctx, "tanhf", "tanh");
-  patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(ctx, "roundevenf",
-                                                      "roundeven");
-  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(ctx, "roundf", "round");
-  patterns.add<ScalarOpToLibmCall<math::CosOp>>(ctx, "cosf", "cos");
-  patterns.add<ScalarOpToLibmCall<math::SinOp>>(ctx, "sinf", "sin");
-  patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(ctx, "log1pf", "log1p");
-  patterns.add<ScalarOpToLibmCall<math::FloorOp>>(ctx, "floorf", "floor");
-  patterns.add<ScalarOpToLibmCall<math::CeilOp>>(ctx, "ceilf", "ceil");
-  patterns.add<ScalarOpToLibmCall<math::TruncOp>>(ctx, "truncf", "trunc");
+
+  populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
+  populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
+  populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
+  populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
+  populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
+  populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
+  populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
+  populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
+  populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
+  populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+                                           "roundeven");
+  populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
+  populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
+  populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
+  populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
+  populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
 }
 
 namespace {