diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -29,6 +29,24 @@
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
 
+private:
+  std::string floatFunc, doubleFunc;
+};
+
+// Pattern to convert scalar complex operations that return a real
+// floating-point value (e.g. complex.abs) into calls to the corresponding libm
+// functions, declaring those libm functions if they are not yet declared.
+template <typename Op>
+struct FloatScalarOpToLibmCall : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+  FloatScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+                          StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc){};
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
 private:
   std::string floatFunc, doubleFunc;
 };
@@ -65,6 +83,42 @@
   return success();
 }
 
+template <typename Op>
+LogicalResult
+FloatScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+                                             PatternRewriter &rewriter) const {
+  auto module = SymbolTable::getNearestSymbolTable(op);
+  // libm only provides single- and double-precision variants (`floatFunc`
+  // and `doubleFunc`). Fail to match on any other result type (f16, bf16,
+  // f80, ...) instead of declaring a callee whose signature disagrees with
+  // the call created below.
+  auto elementType = op.getType().template dyn_cast<FloatType>();
+  if (!elementType || !(elementType.isF32() || elementType.isF64()))
+    return failure();
+
+  StringRef name = elementType.isF64() ? doubleFunc : floatFunc;
+  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+      SymbolTable::lookupSymbolIn(module, name));
+  // Forward declare function if it hasn't already been
+  if (!opFunc) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+    // The declared result type is the op's own result type, so the
+    // declaration always agrees with the call emitted below.
+    auto opFunctionTy = FunctionType::get(rewriter.getContext(),
+                                          op->getOperandTypes(), elementType);
+    opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
+                                           opFunctionTy);
+    opFunc.setPrivate();
+  }
+  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
+
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, elementType,
+                                            op->getOperands());
+
+  return success();
+}
+
 void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
   patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
@@ -77,6 +131,8 @@
                                                    "ccosf", "ccos", benefit);
   patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(),
                                                    "csinf", "csin", benefit);
+  patterns.add<FloatScalarOpToLibmCall<complex::AbsOp>>(
+      patterns.getContext(), "cabsf", "cabs", benefit);
 }
 
 namespace {
@@ -94,7 +150,8 @@
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
-  target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
+  target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
+                      complex::CosOp, complex::SinOp, complex::AbsOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
--- a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir
@@ -8,6 +8,7 @@
 // CHECK-DAG: @ctanh(complex<f64>) -> complex<f64>
 // CHECK-DAG: @ccos(complex<f64>) -> complex<f64>
 // CHECK-DAG: @csin(complex<f64>) -> complex<f64>
+// CHECK-DAG: @cabs(complex<f64>) -> f64
 
 // CHECK-LABEL: func @cpow_caller
 // CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
@@ -67,4 +68,16 @@
   %double_result = complex.sin %double : complex<f64>
   // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
   return %float_result, %double_result : complex<f32>, complex<f64>
+}
+
+// CHECK-LABEL: func @cabs_caller
+// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
+// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
+func.func @cabs_caller(%float: complex<f32>, %double: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
+  %float_result = complex.abs %float : complex<f32>
+  // CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
+  %double_result = complex.abs %double : complex<f64>
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
 }
\ No newline at end of file