diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -892,6 +892,7 @@
     rounded using the default rounding mode. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -906,6 +907,7 @@
     rounded using the default rounding mode. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -919,6 +921,7 @@
     towards zero) unsigned integer value. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -932,6 +935,7 @@
     towards zero) signed integer value. When operating on vectors, casts
     elementwise.
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -14,6 +14,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/APSInt.h"
+
 using namespace mlir;
 using namespace mlir::arith;
 
@@ -881,6 +883,18 @@
   return checkIntFloatCast(inputs, outputs);
 }
 
+OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
+  // Fold a constant operand read as unsigned, rounding to nearest-even.
+  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
+    FloatType floatTy = getType().cast<FloatType>();
+    APFloat apf = APFloat::getZero(floatTy.getFloatSemantics());
+    apf.convertFromAPInt(lhs.getValue(), /*IsSigned=*/false,
+                         APFloat::rmNearestTiesToEven);
+    return FloatAttr::get(floatTy, apf);
+  }
+  return {};
+}
+
//===----------------------------------------------------------------------===// // SIToFPOp //===----------------------------------------------------------------------===// @@ -889,6 +903,17 @@ return checkIntFloatCast(inputs, outputs); } +OpFoldResult arith::SIToFPOp::fold(ArrayRef operands) { + if (auto lhs = operands[0].dyn_cast_or_null()) { + const APInt &api = lhs.getValue(); + FloatType floatTy = getType().cast(); + APFloat apf(floatTy.getFloatSemantics(), + APInt::getZero(floatTy.getWidth())); + apf.convertFromAPInt(api, /*signed=*/true, APFloat::rmNearestTiesToEven); + return FloatAttr::get(floatTy, apf); + } + return {}; +} //===----------------------------------------------------------------------===// // FPToUIOp //===----------------------------------------------------------------------===// @@ -897,6 +922,24 @@ return checkIntFloatCast(inputs, outputs); } +OpFoldResult arith::FPToUIOp::fold(ArrayRef operands) { + if (auto lhs = operands[0].dyn_cast_or_null()) { + const APFloat &apf = lhs.getValue(); + IntegerType intTy = getType().cast(); + bool ignored; + APSInt api(intTy.getWidth(), /*unsigned=*/true); + if (APFloat::opInvalidOp == + apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { + // Undefined behavior invoked - the destination type can't represent + // the input constant. 
+ return {}; + } + return IntegerAttr::get(getType(), api); + } + + return {}; +} + //===----------------------------------------------------------------------===// // FPToSIOp //===----------------------------------------------------------------------===// @@ -905,6 +948,24 @@ return checkIntFloatCast(inputs, outputs); } +OpFoldResult arith::FPToSIOp::fold(ArrayRef operands) { + if (auto lhs = operands[0].dyn_cast_or_null()) { + const APFloat &apf = lhs.getValue(); + IntegerType intTy = getType().cast(); + bool ignored; + APSInt api(intTy.getWidth(), /*unsigned=*/false); + if (APFloat::opInvalidOp == + apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { + // Undefined behavior invoked - the destination type can't represent + // the input constant. + return {}; + } + return IntegerAttr::get(getType(), api); + } + + return {}; +} + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -484,3 +484,65 @@ %3 = arith.minui %arg0, %c0 : i8 return %0, %1, %2, %3: i8, i8, i8, i8 } + +// ----- + +// CHECK-LABEL: @constant_FPtoUI( +func @constant_FPtoUI() -> i32 { + // CHECK: %[[C0:.+]] = arith.constant 2 : i32 + // CHECK: return %[[C0]] + %c0 = arith.constant 2.0 : f32 + %res = arith.fptoui %c0 : f32 to i32 + return %res : i32 +} + +// ----- +// CHECK-LABEL: @invalid_constant_FPtoUI( +func @invalid_constant_FPtoUI() -> i32 { + // CHECK: %[[C0:.+]] = arith.constant -2.000000e+00 : f32 + // CHECK: %[[C1:.+]] = arith.fptoui %[[C0]] : f32 to i32 + // CHECK: return %[[C1]] + %c0 = arith.constant -2.0 : f32 + %res = arith.fptoui %c0 : f32 to i32 + return %res : i32 +} + +// ----- +// CHECK-LABEL: @constant_FPtoSI( +func 
@constant_FPtoSI() -> i32 { + // CHECK: %[[C0:.+]] = arith.constant -2 : i32 + // CHECK: return %[[C0]] + %c0 = arith.constant -2.0 : f32 + %res = arith.fptosi %c0 : f32 to i32 + return %res : i32 +} + +// ----- +// CHECK-LABEL: @invalid_constant_FPtoSI( +func @invalid_constant_FPtoSI() -> i8 { + // CHECK: %[[C0:.+]] = arith.constant 2.000000e+10 : f32 + // CHECK: %[[C1:.+]] = arith.fptosi %[[C0]] : f32 to i8 + // CHECK: return %[[C1]] + %c0 = arith.constant 2.0e10 : f32 + %res = arith.fptosi %c0 : f32 to i8 + return %res : i8 +} + +// CHECK-LABEL: @constant_SItoFP( +func @constant_SItoFP() -> f32 { + // CHECK: %[[C0:.+]] = arith.constant -2.000000e+00 : f32 + // CHECK: return %[[C0]] + %c0 = arith.constant -2 : i32 + %res = arith.sitofp %c0 : i32 to f32 + return %res : f32 +} + +// ----- +// CHECK-LABEL: @constant_UItoFP( +func @constant_UItoFP() -> f32 { + // CHECK: %[[C0:.+]] = arith.constant 2.000000e+00 : f32 + // CHECK: return %[[C0]] + %c0 = arith.constant 2 : i32 + %res = arith.sitofp %c0 : i32 to f32 + return %res : f32 +}