diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -63,6 +63,7 @@ def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">; def LLVM_RoundEvenOp : LLVM_UnaryIntrinsicOp<"roundeven">; def LLVM_RoundOp : LLVM_UnaryIntrinsicOp<"round">; +def LLVM_FTruncOp : LLVM_UnaryIntrinsicOp<"trunc">; def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">; def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">; def LLVM_PowIOp : LLVM_BinaryIntrinsicOp<"powi">; diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -802,6 +802,35 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// TruncOp +//===----------------------------------------------------------------------===// + +def Math_TruncOp : Math_FloatUnaryOp<"trunc"> { + let summary = "trunc of the specified value"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `math.trunc` ssa-use `:` type + ``` + + The `trunc` operation returns the operand rounded toward zero to an + integral value in floating-point format. It takes one operand of floating point type + (i.e., scalar, tensor or vector) and produces one result of the same type. + The operation always rounds to the nearest integer not larger in magnitude + than the operand, regardless of the current rounding direction. + + Example: + + ```mlir + // Scalar trunc operation. 
+ %a = math.trunc %b : f64 + ``` + }]; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // FPowIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -47,6 +47,8 @@ VectorConvertToLLVMPattern; using SinOpLowering = VectorConvertToLLVMPattern; using SqrtOpLowering = VectorConvertToLLVMPattern; +using FTruncOpLowering = + VectorConvertToLLVMPattern; // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`. template @@ -297,7 +299,8 @@ RoundOpLowering, RsqrtOpLowering, SinOpLowering, - SqrtOpLowering + SqrtOpLowering, + FTruncOpLowering >(converter); // clang-format on } diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -155,19 +155,19 @@ void mlir::populateMathToLibmConversionPatterns( RewritePatternSet &patterns, PatternBenefit benefit, llvm::Optional log1pBenefit) { - patterns - .add, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp>( - patterns.getContext(), benefit); + patterns.add, VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp, + VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp>( + patterns.getContext(), benefit); patterns.add, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32>( - patterns.getContext(), benefit); + PromoteOpToF32, PromoteOpToF32, + PromoteOpToF32>(patterns.getContext(), benefit); 
patterns.add>(patterns.getContext(), "atanf", "atan", benefit); patterns.add>(patterns.getContext(), @@ -194,6 +194,8 @@ "floorf", "floor", benefit); patterns.add>(patterns.getContext(), "ceilf", "ceil", benefit); + patterns.add>(patterns.getContext(), + "truncf", "trunc", benefit); } namespace { diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -445,6 +445,23 @@ }); } +//===----------------------------------------------------------------------===// +// TruncOp folder +//===----------------------------------------------------------------------===// + +/// Folds `math.trunc` of a constant by rounding toward zero with +/// `APFloat::roundToIntegral`. Unlike calling the host `trunc`/`truncf`, this +/// works for every float semantics (f16, bf16, f128, ...), not just f32/f64. +OpFoldResult math::TruncOp::fold(ArrayRef<Attribute> operands) { + return constFoldUnaryOpConditional<FloatAttr>( + operands, [](const APFloat &a) -> Optional<APFloat> { + APFloat result(a); + // Rounding toward zero keeps the input's sign, so -0.5 folds to -0.0. + result.roundToIntegral(llvm::RoundingMode::TowardZero); + return result; + }); +} + /// Materialize an integer or floating point constant. 
Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -200,3 +200,13 @@ %0 = math.roundeven %arg0 : f32 func.return } + +// ----- + +// CHECK-LABEL: func @trunc( +// CHECK-SAME: f32 +func.func @trunc(%arg0 : f32) { + // CHECK: "llvm.intr.trunc"(%arg0) : (f32) -> f32 + %0 = math.trunc %arg0 : f32 // Expected to lower to the llvm.intr.trunc intrinsic. + func.return +} diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -16,6 +16,8 @@ // CHECK-DAG: @roundf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @roundeven(f64) -> f64 attributes {llvm.readnone} // CHECK-DAG: @roundevenf(f32) -> f32 attributes {llvm.readnone} +// CHECK-DAG: @trunc(f64) -> f64 attributes {llvm.readnone} +// CHECK-DAG: @truncf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @cos(f64) -> f64 attributes {llvm.readnone} // CHECK-DAG: @cosf(f32) -> f32 attributes {llvm.readnone} // CHECK-DAG: @sin(f64) -> f64 attributes {llvm.readnone} @@ -227,6 +229,17 @@ return %float_result, %double_result : f32, f64 } +// CHECK-LABEL: func @trunc_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @truncf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.trunc %float : f32 // f32 becomes a call to @truncf; f64 uses @trunc. 
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @trunc(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.trunc %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} // CHECK-LABEL: func @cos_caller // CHECK-SAME: %[[FLOAT:.*]]:
f32 @@ -300,6 +313,29 @@ return %float_result, %double_result : vector<2xf32>, vector<2xf64> } +// CHECK-LABEL: func @trunc_vec_caller( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { +func.func @trunc_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { + // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> + // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> + // CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32> + // CHECK: %[[OUT0_F32:.*]] = call @truncf(%[[IN0_F32]]) : (f32) -> f32 + // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> + // CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32> + // CHECK: %[[OUT1_F32:.*]] = call @truncf(%[[IN1_F32]]) : (f32) -> f32 + // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> + %float_result = math.trunc %float : vector<2xf32> // Unrolled: one scalar libm call per element, reinserted into a zero vector. + // CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64> + // CHECK: %[[OUT0_F64:.*]] = call @trunc(%[[IN0_F64]]) : (f64) -> f64 + // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> + // CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64> + // CHECK: %[[OUT1_F64:.*]] = call @trunc(%[[IN1_F64]]) : (f64) -> f64 + // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> + %double_result = math.trunc %double : vector<2xf64> + // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> + return %float_result, %double_result : vector<2xf32>, vector<2xf64> +} // CHECK-LABEL: func @tan_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++
b/mlir/test/Dialect/Math/canonicalize.mlir @@ -411,3 +411,21 @@ %0 = math.round %v1 : vector<4xf32> return %0 : vector<4xf32> } + +// CHECK-LABEL: @trunc_fold +// CHECK-NEXT: %[[cst:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-NEXT: return %[[cst]] +func.func @trunc_fold() -> f32 { + %c = arith.constant 1.1 : f32 + %r = math.trunc %c : f32 // Rounds toward zero: 1.1 folds to 1.0. + return %r : f32 +} + +// CHECK-LABEL: @trunc_fold_vec +// CHECK-NEXT: %[[cst:.+]] = arith.constant dense<[0.000000e+00, -0.000000e+00, 1.000000e+00, -1.000000e+00]> : vector<4xf32> +// CHECK-NEXT: return %[[cst]] +func.func @trunc_fold_vec() -> (vector<4xf32>) { + %v = arith.constant dense<[0.5, -0.5, 1.5, -1.5]> : vector<4xf32> + %0 = math.trunc %v : vector<4xf32> // The sign is kept: -0.5 folds to -0.0 and -1.5 to -1.0. + return %0 : vector<4xf32> +} diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -257,3 +257,15 @@ %2 = math.ipowi %t, %t : tensor<4x4x?xi32> return } + +// CHECK-LABEL: func @trunc( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func.func @trunc(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.trunc %[[F]] : f32 + %0 = math.trunc %f : f32 + // CHECK: %{{.*}} = math.trunc %[[V]] : vector<4xf32> + %1 = math.trunc %v : vector<4xf32> + // CHECK: %{{.*}} = math.trunc %[[T]] : tensor<4x4x?xf32> + %2 = math.trunc %t : tensor<4x4x?xf32> + return +}