diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -652,4 +652,32 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RoundOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundOp : Math_FloatUnaryOp<"round"> {
+  let summary = "round of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.round` ssa-use `:` type
+    ```
+
+    The `round` operation returns the operand rounded to the nearest integer
+    value in floating-point format. It takes one operand of floating point type
+    (i.e., scalar, tensor or vector) and produces one result of the same type.
+    Halfway cases are rounded away from zero, regardless of the current
+    rounding direction (e.g. 2.5 rounds to 3.0 and -2.5 rounds to -3.0).
+
+    Example:
+
+    ```mlir
+    // Scalar round operation.
+    %a = math.round %b : f64
+    ```
+  }];
+}
+
 #endif // MATH_OPS
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -37,6 +37,8 @@
 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+using RoundOpLowering =
+    VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
 
 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
 template <typename MathOp, typename LLVMOp>
@@ -285,7 +287,8 @@
     PowFOpLowering,
     RsqrtOpLowering,
     SinOpLowering,
-    SqrtOpLowering
+    SqrtOpLowering,
+    RoundOpLowering
   >(converter);
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -152,6 +152,8 @@
                                                   "expm1f", "expm1", benefit);
   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
                                                  "tanh", benefit);
+  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
+                                                  "roundf", "round", benefit);
 }
 
 namespace {
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -172,3 +172,12 @@
   func.return
 }
 
+// -----
+
+// CHECK-LABEL: func @round(
+// CHECK-SAME: f32
+func.func @round(%arg0 : f32) {
+  // CHECK: "llvm.intr.round"(%arg0) : (f32) -> f32
+  %0 = math.round %arg0 : f32
+  func.return
+}
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -8,6 +8,8 @@
 // CHECK-DAG: @atan2f(f32, f32) -> f32
 // CHECK-DAG: @tanh(f64) -> f64
 // CHECK-DAG: @tanhf(f32) -> f32
+// CHECK-DAG: @round(f64) -> f64
+// CHECK-DAG: @roundf(f32) -> f32
 
 // CHECK-LABEL: func @tanh_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -21,7 +23,6 @@
   return %float_result, %double_result : f32, f64
 }
 
-
 // CHECK-LABEL: func @atan2_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -116,3 +117,15 @@
 // CHECK: %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
 // CHECK: return %[[VAL_4]] : vector<2x2xf32>
 // CHECK: }
+
+// CHECK-LABEL: func @round_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.round %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @round(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.round %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -194,3 +194,15 @@
   %2 = math.tanh %t : tensor<4x4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @round(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.round %[[F]] : f32
+  %0 = math.round %f : f32
+  // CHECK: %{{.*}} = math.round %[[V]] : vector<4xf32>
+  %1 = math.round %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.round %[[T]] : tensor<4x4x?xf32>
+  %2 = math.round %t : tensor<4x4x?xf32>
+  return
+}