diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -18,12 +18,14 @@
 namespace {
 // Pattern to convert scalar complex operations to calls to libm functions.
 // Additionally the libm function signatures are declared.
-template
+template
 struct ScalarOpToLibmCall : public OpRewritePattern {
 public:
   using OpRewritePattern::OpRewritePattern;
-  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
-                     StringRef doubleFunc, PatternBenefit benefit)
+  ScalarOpToLibmCall(MLIRContext *context,
+                     StringRef floatFunc,
+                     StringRef doubleFunc,
+                     PatternBenefit benefit)
       : OpRewritePattern(context, benefit), floatFunc(floatFunc),
         doubleFunc(doubleFunc){};
 
@@ -32,20 +34,45 @@
 private:
   std::string floatFunc, doubleFunc;
 };
+
+// Functor to resolve the libm function name for the given complex type:
+// `doubleFunc` for a 64-bit element type, `floatFunc` otherwise. An empty
+// string means the element type failed the check; callers treat it as failure.
+struct ComplexTypeFuncResolver {
+  std::string operator()(Type type, const std::string &doubleFunc,
+                         const std::string &floatFunc) const {
+    auto complexType = type.cast();
+    auto elementType = complexType.getElementType();
+    if (!elementType.isa())
+      return {};
+
+    return elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+  }
+};
+
+// Functor to resolve the libm function name for the given float type:
+// `doubleFunc` for a 64-bit type, `floatFunc` otherwise. An empty string means
+// the type failed the check; callers treat it as failure.
+struct FloatTypeFuncResolver {
+  std::string operator()(Type type, const std::string &doubleFunc,
+                         const std::string &floatFunc) const {
+    auto elementType = type.cast();
+    if (!elementType.isa())
+      return {};
+
+    return elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+  }
+};
 } // namespace
 
-template
-LogicalResult
-ScalarOpToLibmCall::matchAndRewrite(Op op,
-                                    PatternRewriter &rewriter) const {
+template
+LogicalResult ScalarOpToLibmCall::matchAndRewrite(
+    Op op, PatternRewriter &rewriter) const {
   auto module = SymbolTable::getNearestSymbolTable(op);
-  auto type = op.getType().template cast();
-  Type elementType = type.getElementType();
-  if (!elementType.isa())
+  auto name = FuncResolver()(op.getType(), doubleFunc, floatFunc);
+  if (name.empty())
     return failure();
 
-  auto name =
-      elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
   auto opFunc = dyn_cast_or_null(
       SymbolTable::lookupSymbolIn(module, name));
   // Forward declare function if it hasn't already been
@@ -60,25 +87,28 @@
   }
   assert(isa(SymbolTable::lookupSymbolIn(module, name)));
 
-  rewriter.replaceOpWithNewOp(op, name, type, op->getOperands());
+  rewriter.replaceOpWithNewOp(op, name, op.getType(),
+                              op->getOperands());
 
   return success();
 }
 
 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
-  patterns.add>(patterns.getContext(),
-                "cpowf", "cpow", benefit);
-  patterns.add>(patterns.getContext(),
-                "csqrtf", "csqrt", benefit);
-  patterns.add>(patterns.getContext(),
-                "ctanhf", "ctanh", benefit);
-  patterns.add>(patterns.getContext(),
-                "ccosf", "ccos", benefit);
-  patterns.add>(patterns.getContext(),
-                "csinf", "csin", benefit);
-  patterns.add>(patterns.getContext(),
-                "conjf", "conj", benefit);
+  patterns.add>(
+      patterns.getContext(), "cpowf", "cpow", benefit);
+  patterns.add>(
+      patterns.getContext(), "csqrtf", "csqrt", benefit);
+  patterns.add>(
+      patterns.getContext(), "ctanhf", "ctanh", benefit);
+  patterns.add>(
+      patterns.getContext(), "ccosf", "ccos", benefit);
+  patterns.add>(
+      patterns.getContext(), "csinf", "csin", benefit);
+  patterns.add>(
+      patterns.getContext(), "conjf", "conj", benefit);
+  patterns.add>(
+      patterns.getContext(), "cabsf", "cabs", benefit);
 }
 
 namespace {
@@ -96,7 +126,8 @@
 
   ConversionTarget target(getContext());
   target.addLegalDialect();
-  target.addIllegalOp();
+  target.addIllegalOp();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
--- a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
@@ -9,6 +9,7 @@
 // CHECK-DAG: @ccos(complex) -> complex
 // CHECK-DAG: @csin(complex) -> complex
 // CHECK-DAG: @conj(complex) -> complex
+// CHECK-DAG: @cabs(complex) -> f64
 
 // CHECK-LABEL: func @cpow_caller
 // CHECK-SAME: %[[FLOAT:.*]]: complex
@@ -80,4 +81,16 @@
   %double_result = complex.conj %double : complex
   // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
   return %float_result, %double_result : complex, complex
+}
+
+// CHECK-LABEL: func @cabs_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex
+// CHECK-SAME: %[[DOUBLE:.*]]: complex
+func.func @cabs_caller(%float: complex, %double: complex) -> (f32, f64) {
+  // CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
+  %float_result = complex.abs %float : complex
+  // CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
+  %double_result = complex.abs %double : complex
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
 }
\ No newline at end of file