diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1047,6 +1047,7 @@
     When operating on vectors, casts elementwise.
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1224,6 +1224,26 @@
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
+/// Fold an extension of a scalar floating-point constant into a constant of
+/// the result type. Returns a null result (no fold) when the operand is not a
+/// `FloatAttr` constant or the result type is not a scalar `FloatType`.
+OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
+  auto resultType = getType().dyn_cast<FloatType>();
+  if (!constOperand || !resultType)
+    return {};
+
+  // Convert directly into the result type's semantics rather than going
+  // through a host 'double': APFloat::convertToDouble() is only valid for
+  // semantics representable in a double, so sources wider than f64
+  // (e.g. f80 -> f128) could not be folded that way.
+  APFloat value = constOperand.getValue();
+  // Extension is expected to be value-preserving, so the flag is not inspected.
+  bool losesInfo = false;
+  value.convert(resultType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                &losesInfo);
+  return FloatAttr::get(resultType, value);
+}
+
 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -502,6 +502,15 @@
   return %ext : vector<4xi16>
 }
 
+// CHECK-LABEL: @extFPConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : f64
+// CHECK: return %[[cres]]
+func.func @extFPConstant() -> f64 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.extf %cst : f32 to f64
+  return %0 : f64
+}
+
 // CHECK-LABEL: @truncConstant
 // CHECK: %[[cres:.+]] = arith.constant -2 : i16
 // CHECK: return %[[cres]]