[mlir][MathToLibm] Emit llvm.func/llvm.call instead of func.func/func.call

This patch changes the MathToLibm conversion so that the libm function it
forward-declares ("Forward declare function if it hasn't already been") is
an `llvm.func` whose type is built with LLVM::LLVMFunctionType, and the
matched math op is replaced by a call on that declaration (the updated tests
expect `llvm.call @<libm name>`), instead of a `func.func` + `func.call` pair.

Per file:
  * Passes.td: the pass's dependentDialects entry "func::FuncDialect" is
    replaced by "LLVM::LLVMDialect".
  * MathToLibm.cpp:
      - FuncOps.h is replaced by LLVMDialect.h.
      - LLVM::LLVMDialect is added to the conversion target's legal dialects.
      - A new assert requires the matched op to have exactly one result,
        because the LLVM function type is built from getResultTypes()[0].
      - The call is now built from the declaration op (opFunc) rather than
        from the symbol name plus op.getType().
  * convert-to-libm.mlir: the CHECK lines expect `llvm.call @...` where they
    previously expected `call @...`.

Review notes (to confirm before landing):
  * NOTE(review): in this copy, text between angle brackets appears to have
    been stripped. Examples are the template arguments of dyn_cast_or_null,
    rewriter.create, replaceOpWithNewOp, isa and addLegalDialect. Line
    breaks have also been collapsed. The patch therefore cannot be applied
    as-is with `git apply`; regenerate it from the source branch.
  * NOTE(review): the symbol lookup result is dyn_cast to (presumably, given
    the declared variable type) LLVM::LLVMFuncOp.
      - If the module already holds a non-LLVM symbol with the libm name
        (e.g. a `func.func private @tanhf` declaration), opFunc is null.
      - A second op named `name` is then created at the start of the module.
        The symbol table would presumably reject that as a redefinition.
      - Consider returning a match failure when an existing symbol is not an
        llvm.func.
  * NOTE(review): the unchanged context line `assert(isa<...>(...))` must
    name something that llvm.func satisfies (e.g. an interface), not
    func::FuncOp, since FuncOps.h is no longer included. Please confirm.
  * NOTE(review): `opFunc.setPrivate()` is kept on the llvm.func declaration.
    Confirm that llvm.func accepts and round-trips private symbol visibility
    in this MLIR version.
  * NOTE(review): lines 1-12 of convert-to-libm.mlir are outside these hunks.
    Confirm that any CHECK lines there for the libm declarations still match
    the printed `llvm.func` form.

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -476,7 +476,7 @@ let constructor = "mlir::createConvertMathToLibmPass()"; let dependentDialects = [ "arith::ArithmeticDialect", - "func::FuncDialect", + "LLVM::LLVMDialect", "vector::VectorDialect", ]; } diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -10,7 +10,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -112,28 +112,28 @@ LogicalResult ScalarOpToLibmCall::matchAndRewrite(Op op, PatternRewriter &rewriter) const { + assert(op->getNumResults() == 1 && "expect 1 result only"); auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType(); if (!type.template isa()) return failure(); auto name = type.getIntOrFloatBitWidth() == 64 ?
doubleFunc : floatFunc; - auto opFunc = dyn_cast_or_null( + LLVM::LLVMFuncOp opFunc = dyn_cast_or_null( SymbolTable::lookupSymbolIn(module, name)); // Forward declare function if it hasn't already been if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&module->getRegion(0).front()); - auto opFunctionTy = FunctionType::get( - rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create(rewriter.getUnknownLoc(), name, - opFunctionTy); + auto opFunctionTy = LLVM::LLVMFunctionType::get( + op->getResultTypes()[0], llvm::to_vector<4>(op->getOperandTypes())); + opFunc = rewriter.create(rewriter.getUnknownLoc(), name, + opFunctionTy); opFunc.setPrivate(); } assert(isa(SymbolTable::lookupSymbolIn(module, name))); - rewriter.replaceOpWithNewOp(op, name, op.getType(), - op->getOperands()); + rewriter.replaceOpWithNewOp(op, opFunc, op->getOperands()); return success(); } @@ -169,7 +169,7 @@ ConversionTarget target(getContext()); target.addLegalDialect(); + LLVM::LLVMDialect, vector::VectorDialect>(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -13,9 +13,9 @@ // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) { - // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @tanhf(%[[FLOAT]]) : (f32) -> f32 + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = llvm.call @tanhf(%[[FLOAT]]) : (f32) -> f32 %float_result = math.tanh %float : f32 - // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @tanh(%[[DOUBLE]]) : (f64) -> f64 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = llvm.call @tanh(%[[DOUBLE]]) : (f64) -> f64 %double_result = math.tanh %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] return %float_result, %double_result : f32, f64 @@ -28,18 +28,18 @@ // CHECK-SAME: %[[HALF:.*]]: f16 // CHECK-SAME: %[[BFLOAT:.*]]: bf16 func.func @atan2_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) -> (f32, f64, f16, bf16) { - // CHECK: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32 + // CHECK: %[[FLOAT_RESULT:.*]] = llvm.call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32 %float_result = math.atan2 %float, %float : f32 - // CHECK: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64 + // CHECK: %[[DOUBLE_RESULT:.*]] = llvm.call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64 %double_result = math.atan2 %double, %double : f64 // CHECK: %[[HALF_PROMOTED1:.*]] = arith.extf %[[HALF]] : f16 to f32 // CHECK: %[[HALF_PROMOTED2:.*]] = arith.extf %[[HALF]] : f16 to f32 - // CHECK: %[[HALF_CALL:.*]] = call @atan2f(%[[HALF_PROMOTED1]], %[[HALF_PROMOTED2]]) : (f32, f32) -> f32 + // CHECK: %[[HALF_CALL:.*]] = llvm.call @atan2f(%[[HALF_PROMOTED1]], %[[HALF_PROMOTED2]]) : (f32, f32) -> f32 // CHECK: %[[HALF_RESULT:.*]] = arith.truncf %[[HALF_CALL]] : f32 to f16 %half_result = math.atan2 %half, %half : f16 // CHECK: %[[BFLOAT_PROMOTED1:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32 // CHECK: %[[BFLOAT_PROMOTED2:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32 - // CHECK: %[[BFLOAT_CALL:.*]] = call @atan2f(%[[BFLOAT_PROMOTED1]], %[[BFLOAT_PROMOTED2]]) : (f32, f32) -> f32 + // CHECK: %[[BFLOAT_CALL:.*]] = llvm.call @atan2f(%[[BFLOAT_PROMOTED1]], %[[BFLOAT_PROMOTED2]]) : (f32, f32) -> f32 // CHECK: %[[BFLOAT_RESULT:.*]] = arith.truncf %[[BFLOAT_CALL]] : f32 to bf16 %bfloat_result = math.atan2 %bfloat, %bfloat : bf16 // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]], %[[HALF_RESULT]], %[[BFLOAT_RESULT]] @@ -50,9 +50,9 @@ // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 func.func @erf_caller(%float: f32, %double: f64) -> (f32,
f64) { - // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @erff(%[[FLOAT]]) : (f32) -> f32 + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = llvm.call @erff(%[[FLOAT]]) : (f32) -> f32 %float_result = math.erf %float : f32 - // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @erf(%[[DOUBLE]]) : (f64) -> f64 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = llvm.call @erf(%[[DOUBLE]]) : (f64) -> f64 %double_result = math.erf %double : f64 // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] return %float_result, %double_result : f32, f64 @@ -62,9 +62,9 @@ // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 func.func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) { - // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @expm1f(%[[FLOAT]]) : (f32) -> f32 + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = llvm.call @expm1f(%[[FLOAT]]) : (f32) -> f32 %float_result = math.expm1 %float : f32 - // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @expm1(%[[DOUBLE]]) : (f64) -> f64 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = llvm.call @expm1(%[[DOUBLE]]) : (f64) -> f64 %double_result = math.expm1 %double : f64 // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] return %float_result, %double_result : f32, f64 @@ -81,16 +81,16 @@ // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> // CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32> -// CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32 +// CHECK: %[[OUT0_F32:.*]] = llvm.call @expm1f(%[[IN0_F32]]) : (f32) -> f32 // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> // CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32> -// CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32 +// CHECK: %[[OUT1_F32:.*]] = llvm.call @expm1f(%[[IN1_F32]]) : (f32) -> f32 // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> // CHECK:
%[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64> -// CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64 +// CHECK: %[[OUT0_F64:.*]] = llvm.call @expm1(%[[IN0_F64]]) : (f64) -> f64 // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> // CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64> -// CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64 +// CHECK: %[[OUT1_F64:.*]] = llvm.call @expm1(%[[IN1_F64]]) : (f64) -> f64 // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> // CHECK: } @@ -103,16 +103,16 @@ // CHECK-SAME: %[[VAL:.*]]: vector<2x2xf32> // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> // CHECK: %[[IN0_0_F32:.*]] = vector.extract %[[VAL]][0, 0] : vector<2x2xf32> -// CHECK: %[[OUT0_0_F32:.*]] = call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32 +// CHECK: %[[OUT0_0_F32:.*]] = llvm.call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32 // CHECK: %[[VAL_1:.*]] = vector.insert %[[OUT0_0_F32]], %[[CVF]] [0, 0] : f32 into vector<2x2xf32> // CHECK: %[[IN0_1_F32:.*]] = vector.extract %[[VAL]][0, 1] : vector<2x2xf32> -// CHECK: %[[OUT0_1_F32:.*]] = call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32 +// CHECK: %[[OUT0_1_F32:.*]] = llvm.call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32 // CHECK: %[[VAL_2:.*]] = vector.insert %[[OUT0_1_F32]], %[[VAL_1]] [0, 1] : f32 into vector<2x2xf32> // CHECK: %[[IN1_0_F32:.*]] = vector.extract %[[VAL]][1, 0] : vector<2x2xf32> -// CHECK: %[[OUT1_0_F32:.*]] = call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32 +// CHECK: %[[OUT1_0_F32:.*]] = llvm.call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32 // CHECK: %[[VAL_3:.*]] = vector.insert %[[OUT1_0_F32]], %[[VAL_2]] [1, 0] : f32 into vector<2x2xf32> // CHECK: %[[IN1_1_F32:.*]] = vector.extract %[[VAL]][1, 1] : vector<2x2xf32> -// CHECK: %[[OUT1_1_F32:.*]] = call
@expm1f(%[[IN1_1_F32]]) : (f32) -> f32 +// CHECK: %[[OUT1_1_F32:.*]] = llvm.call @expm1f(%[[IN1_1_F32]]) : (f32) -> f32 // CHECK: %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32> // CHECK: return %[[VAL_4]] : vector<2x2xf32> // CHECK: }