diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -892,6 +892,7 @@
     rounded using the default rounding mode. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -906,6 +907,7 @@
     rounded using the default rounding mode. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -919,6 +921,7 @@
     towards zero) unsigned integer value. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -932,6 +935,7 @@
     towards zero) signed integer value. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -14,6 +14,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/APSInt.h"
+
 using namespace mlir;
 using namespace mlir::arith;
 
@@ -881,6 +883,20 @@
   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
 }
 
+// Folds a scalar constant integer operand, interpreted as unsigned, into a
+// float constant rounded to nearest (ties to even).
+OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
+  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
+    const APInt &api = lhs.getValue();
+    FloatType floatTy = getType().cast<FloatType>();
+    APFloat apf(floatTy.getFloatSemantics(),
+                APInt::getZero(floatTy.getWidth()));
+    apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
+    return FloatAttr::get(floatTy, apf);
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // SIToFPOp
 //===----------------------------------------------------------------------===//
@@ -889,6 +905,20 @@
   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
 }
 
+// Folds a scalar constant integer operand, interpreted as signed, into a
+// float constant rounded to nearest (ties to even).
+OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
+  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
+    const APInt &api = lhs.getValue();
+    FloatType floatTy = getType().cast<FloatType>();
+    APFloat apf(floatTy.getFloatSemantics(),
+                APInt::getZero(floatTy.getWidth()));
+    apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
+    return FloatAttr::get(floatTy, apf);
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // FPToUIOp
 //===----------------------------------------------------------------------===//
@@ -897,6 +927,27 @@
   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
 }
 
+// Folds a scalar constant float operand into an unsigned integer constant,
+// rounding towards zero. Does not fold when the conversion is invalid, e.g.
+// the value is negative or out of range for the result type.
+OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
+  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
+    const APFloat &apf = lhs.getValue();
+    IntegerType intTy = getType().cast<IntegerType>();
+    bool ignored;
+    APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+    if (APFloat::opInvalidOp ==
+        apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
+      // Undefined behavior invoked - the destination type can't represent
+      // the input constant.
+      return {};
+    }
+    return IntegerAttr::get(getType(), api);
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // FPToSIOp
 //===----------------------------------------------------------------------===//
@@ -905,6 +956,27 @@
   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
 }
 
+// Folds a scalar constant float operand into a signed integer constant,
+// rounding towards zero. Does not fold when the conversion is invalid, e.g.
+// the value is out of range for the result type.
+OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
+  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
+    const APFloat &apf = lhs.getValue();
+    IntegerType intTy = getType().cast<IntegerType>();
+    bool ignored;
+    APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+    if (APFloat::opInvalidOp ==
+        apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
+      // Undefined behavior invoked - the destination type can't represent
+      // the input constant.
+      return {};
+    }
+    return IntegerAttr::get(getType(), api);
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -484,3 +484,53 @@
   %3 = arith.minui %arg0, %c0 : i8
   return %0, %1, %2, %3: i8, i8, i8, i8
 }
+
+// -----
+
+// CHECK-LABEL: @constantFPtoUI(
+func @constantFPtoUI() -> i32 {
+  // CHECK: %[[C0:.+]] = arith.constant 2 : i32
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant 2.0 : f32
+  %res = arith.fptoui %c0 : f32 to i32
+  return %res : i32
+}
+
+// -2.0 is not representable as an unsigned integer, so the cast must not fold.
+// CHECK-LABEL: @invalidFPtoUI(
+func @invalidFPtoUI() -> i32 {
+  // CHECK: %[[C0:.+]] = arith.constant -2.000000e+00 : f32
+  // CHECK: %[[RES:.+]] = arith.fptoui %[[C0]] : f32 to i32
+  // CHECK: return %[[RES]]
+  %c0 = arith.constant -2.0 : f32
+  %res = arith.fptoui %c0 : f32 to i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @constantFPtoSI(
+func @constantFPtoSI() -> i32 {
+  // CHECK: %[[C0:.+]] = arith.constant -2 : i32
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant -2.0 : f32
+  %res = arith.fptosi %c0 : f32 to i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @constantSItoFP(
+func @constantSItoFP() -> f32 {
+  // CHECK: %[[C0:.+]] = arith.constant -2.000000e+00 : f32
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant -2 : i32
+  %res = arith.sitofp %c0 : i32 to f32
+  return %res : f32
+}
+
+// The i8 bit pattern of -1 (0xFF) reads as 255 when treated as unsigned.
+// CHECK-LABEL: @constantUItoFP(
+func @constantUItoFP() -> f32 {
+  // CHECK: %[[C0:.+]] = arith.constant 2.550000e+02 : f32
+  // CHECK: return %[[C0]]
+  %c0 = arith.constant -1 : i8
+  %res = arith.uitofp %c0 : i8 to f32
+  return %res : f32
+}