diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -20,6 +20,7 @@
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -174,6 +174,62 @@
   return success();
 }
 
+// Converts math.round to arith and math operations. math.round rounds to the
+// nearest integral value, with halfway cases rounded away from zero.
+//
+// The shortcut trunc(x + copysign(0.5, x)) is not exact: the addition itself
+// can round (0.49999997f + 0.5f == 1.0f, and for f32 an odd integer between
+// 2^23 and 2^24 plus 0.5 rounds to the next even integer), and the i64
+// truncation overflows for large magnitudes. Instead this computes
+//   t      = trunc(x)
+//   r      = |x - t| >= 0.5 ? t + copysign(1.0, x) : t
+//   result = |x| >= 2^(p-1) ? x : copysign(r, x)
+// where p is the significand width of the element type. Every float whose
+// magnitude is at least 2^(p-1) is already integral, as are the infinities,
+// and NaN is forwarded because the last comparison is unordered. Within the
+// guarded range x - t and t +/- 1 are exact, and the final copysign keeps
+// results such as round(-0.3) == -0.0 correctly signed.
+static LogicalResult convertRoundOp(math::RoundOp op,
+                                    PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  auto floatType = cast<FloatType>(getElementTypeOrSelf(opType));
+
+  // The truncation goes through an i64, so it is only valid for magnitudes
+  // below 2^63; reject types whose non-integral values can exceed that.
+  int fractionBits = floatType.getFPMantissaWidth() - 1;
+  if (fractionBits > 63)
+    return rewriter.notifyMatchFailure(op, "unsupported floating-point type");
+
+  Value half = createFloatConst(loc, opType, 0.5, rewriter);
+  Value one = createFloatConst(loc, opType, 1.0, rewriter);
+  Value integralBound = createFloatConst(
+      loc, opType, static_cast<double>(uint64_t{1} << fractionBits), rewriter);
+
+  // Round |x| < 2^(p-1) by truncating and then stepping one unit away from
+  // zero when the discarded fraction is at least one half.
+  Value truncated = createTruncatedFPValue(operand, b);
+  Value absOperand = b.create<math::AbsFOp>(operand);
+  Value diff = b.create<arith::SubFOp>(operand, truncated);
+  Value fraction = b.create<math::AbsFOp>(diff);
+  Value roundAway =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, fraction, half);
+  Value step = b.create<math::CopySignOp>(one, operand);
+  Value bumped = b.create<arith::AddFOp>(truncated, step);
+  Value rounded = b.create<arith::SelectOp>(roundAway, bumped, truncated);
+  Value signedRounded = b.create<math::CopySignOp>(rounded, operand);
+
+  // Magnitudes >= 2^(p-1), infinities and NaN are returned unchanged; the
+  // (possibly overflowed) truncation result is discarded for them.
+  Value isIntegral = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGE,
+                                             absOperand, integralBound);
+  Value result = b.create<arith::SelectOp>(isIntegral, operand, signedRounded);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -242,6 +298,10 @@
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertRoundOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -189,3 +189,28 @@
   %ret = math.exp2 %a : tensor<1xf32>
   return %ret : tensor<1xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @roundf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f32) -> f32
+func.func @roundf_func(%a: f32) -> f32 {
+  // CHECK-DAG: [[HALF:%.+]] = arith.constant 5.000000e-01 : f32
+  // CHECK-DAG: [[ONE:%.+]] = arith.constant 1.000000e+00 : f32
+  // CHECK-DAG: [[BOUND:%.+]] = arith.constant 8.388608e+06 : f32
+  // CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK-DAG: [[ABS:%.+]] = math.absf [[ARG0]]
+  // CHECK-DAG: [[DIFF:%.+]] = arith.subf [[ARG0]], [[TRUNC]]
+  // CHECK-DAG: [[FRAC:%.+]] = math.absf [[DIFF]]
+  // CHECK-DAG: [[AWAY:%.+]] = arith.cmpf oge, [[FRAC]], [[HALF]]
+  // CHECK-DAG: [[STEP:%.+]] = math.copysign [[ONE]], [[ARG0]]
+  // CHECK-DAG: [[BUMP:%.+]] = arith.addf [[TRUNC]], [[STEP]]
+  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AWAY]], [[BUMP]], [[TRUNC]]
+  // CHECK-DAG: [[SIGNED:%.+]] = math.copysign [[SEL]], [[ARG0]]
+  // CHECK-DAG: [[BIG:%.+]] = arith.cmpf uge, [[ABS]], [[BOUND]]
+  // CHECK-DAG: [[RES:%.+]] = arith.select [[BIG]], [[ARG0]], [[SIGNED]]
+  // CHECK: return [[RES]]
+  %ret = math.round %a : f32
+  return %ret : f32
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -43,6 +43,7 @@
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -55,7 +55,55 @@
   return
 }
 
+// -------------------------------------------------------------------------- //
+// round.
+// -------------------------------------------------------------------------- //
+// Halfway cases round away from zero. The checks are anchored to whole lines
+// so that, e.g., a printed `-4` cannot satisfy the check for `4`.
+func.func @func_roundf(%a : f32) {
+  %r = math.round %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @roundf() {
+  // CHECK: {{^}}4{{$}}
+  %a = arith.constant 3.8 : f32
+  call @func_roundf(%a) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-4{{$}}
+  %b = arith.constant -3.8 : f32
+  call @func_roundf(%b) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}0{{$}}
+  %c = arith.constant 0.0 : f32
+  call @func_roundf(%c) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-4{{$}}
+  %d = arith.constant -4.2 : f32
+  call @func_roundf(%d) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-495{{$}}
+  %e = arith.constant -495.0 : f32
+  call @func_roundf(%e) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}495{{$}}
+  %f = arith.constant 495.0 : f32
+  call @func_roundf(%f) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}9{{$}}
+  %g = arith.constant 8.5 : f32
+  call @func_roundf(%g) : (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-9{{$}}
+  %h = arith.constant -8.5 : f32
+  call @func_roundf(%h) : (f32) -> ()
+
+  return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
+  call @roundf() : () -> ()
   return
 }