diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -18,6 +18,7 @@ void populateExpandTanhPattern(RewritePatternSet &patterns); void populateExpandFmaFPattern(RewritePatternSet &patterns); void populateExpandFloorFPattern(RewritePatternSet &patterns); +void populateExpandCeilFPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); struct MathPolynomialApproximationOptions { diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -46,6 +46,13 @@ return b.create<arith::ConstantOp>(loc, attr); } +static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { + Type opType = operand.getType(); + Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand); + Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); + return fpFixedConvert; +} + /// Expands tanh op into /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 @@ -112,8 +119,7 @@ ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); - Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand); - Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); + Value fpFixedConvert = createTruncatedFPValue(operand, b); // Creating constants for later use. Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); @@ -128,6 +134,30 @@ return success(); } +// Converts a ceilf() function to the following: +// ceilf(float x) -> +// y = (float)(int) x +// if (x > y) then incr = 1 else incr = 0 +// y = y + incr <= the ceilf op is replaced with this op.
+static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operand = op.getOperand(); + Type opType = operand.getType(); + Value fpFixedConvert = createTruncatedFPValue(operand, b); + + // Creating constants for later use. + Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); + Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); + + Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, + fpFixedConvert); + Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); + + Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); + rewriter.replaceOp(op, ret); + return success(); +} + // Converts math.ctlz to scf and arith operations. This is done // by performing a binary search on the bits. static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, @@ -187,6 +217,11 @@ void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { patterns.add(convertFmaFOp); } + +void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { + patterns.add(convertCeilOp); +} + void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { patterns.add(convertFloorOp); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -148,3 +148,20 @@ %ret = math.floor %a : f64 return %ret : f64 } + +// ----- + +// CHECK-LABEL: func @ceilf_func +// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 +func.func @ceilf_func(%a: f64) -> f64 { + // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000 + // CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000 + // CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]] + // CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]] + // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[CVTF]] + // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]] + // CHECK-NEXT: [[ADDF:%.+]] = arith.addf
[[CVTF]], [[INCR]] + // CHECK-NEXT: return [[ADDF]] + %ret = math.ceil %a : f64 + return %ret : f64 +} diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -41,6 +41,7 @@ populateExpandTanhPattern(patterns); populateExpandFmaFPattern(patterns); populateExpandFloorFPattern(patterns); + populateExpandCeilFPattern(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -647,6 +647,43 @@ return } +// -------------------------------------------------------------------------- // +// ceil. +// -------------------------------------------------------------------------- // +func.func @func_ceilf32(%a : f32) { + %r = math.ceil %a : f32 + vector.print %r : f32 + return +} + +func.func @ceilf() { + // CHECK: 4 + %a = arith.constant 3.8 : f32 + call @func_ceilf32(%a) : (f32) -> () + + // CHECK: -3 + %b = arith.constant -3.8 : f32 + call @func_ceilf32(%b) : (f32) -> () + + // CHECK: 0 + %c = arith.constant 0.0 : f32 + call @func_ceilf32(%c) : (f32) -> () + + // CHECK: -4 + %d = arith.constant -4.2 : f32 + call @func_ceilf32(%d) : (f32) -> () + + // CHECK: -495 + %e = arith.constant -495.0 : f32 + call @func_ceilf32(%e) : (f32) -> () + + // CHECK: 495 + %f = arith.constant 495.0 : f32 + call @func_ceilf32(%f) : (f32) -> () + + return +} + func.func @main() { call @tanh(): () -> () call @log(): () -> () @@ -661,6 +698,7 @@ call @atan2() : () -> () call @cbrt() : () -> () call @floorf() : () -> () + call @ceilf() : () -> () return }