diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp --- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp +++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp @@ -16,14 +16,43 @@ using namespace mlir; namespace { +// Functor mapping a complex result type to an optional flag: true if its +// element type is 64 bits wide, false if not, None if it fails the isa check. +struct ComplexTypeResolver { + llvm::Optional operator()(Type type) const { + auto complexType = type.cast(); + auto elementType = complexType.getElementType(); + if (!elementType.isa()) + return {}; + + return elementType.getIntOrFloatBitWidth() == 64; + } +}; + +// Functor mapping a float result type to an optional flag: true if it is +// 64 bits wide, false if not, None if it fails the isa check. +struct FloatTypeResolver { + llvm::Optional operator()(Type type) const { + auto elementType = type.cast(); + if (!elementType.isa()) + return {}; + + return elementType.getIntOrFloatBitWidth() == 64; + } +}; + // Pattern to convert scalar complex operations to calls to libm functions. // Additionally the libm function signatures are declared. -template +// TypeResolver is a functor mapping the op's result type to an optional flag: +// true selects doubleFunc, false selects floatFunc, None fails the match.
+template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, - StringRef doubleFunc, PatternBenefit benefit) + ScalarOpToLibmCall(MLIRContext *context, + StringRef floatFunc, + StringRef doubleFunc, + PatternBenefit benefit) : OpRewritePattern(context, benefit), floatFunc(floatFunc), doubleFunc(doubleFunc){}; @@ -34,18 +63,16 @@ }; } // namespace -template -LogicalResult -ScalarOpToLibmCall::matchAndRewrite(Op op, - PatternRewriter &rewriter) const { +template +LogicalResult ScalarOpToLibmCall::matchAndRewrite( + Op op, PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); - auto type = op.getType().template cast(); - Type elementType = type.getElementType(); - if (!elementType.isa()) + auto isDouble = TypeResolver()(op.getType()); + if (!isDouble.hasValue()) return failure(); - auto name = - elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; + auto name = isDouble.value() ?
doubleFunc : floatFunc; + auto opFunc = dyn_cast_or_null( SymbolTable::lookupSymbolIn(module, name)); // Forward declare function if it hasn't already been @@ -60,7 +87,8 @@ } assert(isa(SymbolTable::lookupSymbolIn(module, name))); - rewriter.replaceOpWithNewOp(op, name, type, op->getOperands()); + rewriter.replaceOpWithNewOp(op, name, op.getType(), + op->getOperands()); return success(); } @@ -79,6 +107,8 @@ "csinf", "csin", benefit); patterns.add>(patterns.getContext(), "conjf", "conj", benefit); + patterns.add>( + patterns.getContext(), "cabsf", "cabs", benefit); } namespace { @@ -96,7 +126,8 @@ ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/ComplexToLibm/convert-to-libm.mlir @@ -9,6 +9,7 @@ // CHECK-DAG: @ccos(complex) -> complex // CHECK-DAG: @csin(complex) -> complex // CHECK-DAG: @conj(complex) -> complex +// CHECK-DAG: @cabs(complex) -> f64 // CHECK-LABEL: func @cpow_caller // CHECK-SAME: %[[FLOAT:.*]]: complex @@ -80,4 +81,16 @@ %double_result = complex.conj %double : complex // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] return %float_result, %double_result : complex, complex +} + +// CHECK-LABEL: func @cabs_caller +// CHECK-SAME: %[[FLOAT:.*]]: complex +// CHECK-SAME: %[[DOUBLE:.*]]: complex +func.func @cabs_caller(%float: complex, %double: complex) -> (f32, f64) { + // CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]]) + %float_result = complex.abs %float : complex + // CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]]) + %double_result = complex.abs %double : complex + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result :
f32, f64 } \ No newline at end of file