diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1220,6 +1220,8 @@
     If the value cannot be exactly represented, it is rounded using the default
     rounding mode. When operating on vectors, casts elementwise.
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1414,6 +1414,38 @@
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
+/// Perform safe const propagation for fptrunc, i.e. only propagate
+/// if FP value can be represented without precision loss or rounding.
+///
+/// The conversion is done directly between APFloat semantics rather than
+/// through 'double': 'double' cannot hold every source type (e.g. f80 or
+/// f128), and APFloat::convertToDouble is only valid for such types.
+OpFoldResult FPTruncOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  auto constOperand = operands.front().dyn_cast_or_null<FloatAttr>();
+  if (!constOperand)
+    return {};
+
+  // A scalar FloatAttr operand implies a scalar float result; bail out
+  // defensively if the result type is anything else.
+  auto resultType = getType().dyn_cast<FloatType>();
+  if (!resultType)
+    return {};
+
+  // Truncate with the default rounding mode and record whether it was exact.
+  APFloat truncated = constOperand.getValue();
+  bool losesInfo = false;
+  APFloat::opStatus status = truncated.convert(
+      resultType.getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
+
+  // Propagate only if the constant's value does not change after truncation.
+  if (losesInfo || status != APFloat::opOK)
+    return {};
+
+  return FloatAttr::get(resultType, truncated);
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -196,11 +196,10 @@
 }
 
 // CHECK-LABEL: @generalize_soft_plus_2d_f32
-// CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
+// CHECK: %[[C1:.+]] = constant 1.000000e+00 : f32
 // CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-NEXT: %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
 // CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT: %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+// CHECK-NEXT: %[[SUM:.+]] = addf %[[C1]], %[[EXP]] : f32
 // CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
 // CHECK-NEXT: linalg.yield %[[LOG]] : f32
 // CHECK-NEXT: -> tensor<16x32xf32>
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -80,6 +80,35 @@
   return %tr : i16
 }
 
+// CHECK-LABEL: @truncFPConstant
+// CHECK: %[[cres:.+]] = constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func @truncFPConstant() -> bf16 {
+  %cst = constant 1.000000e+00 : f32
+  %0 = fptrunc %cst : f32 to bf16
+  return %0 : bf16
+}
+
+// Test that cases with rounding are NOT propagated
+// CHECK-LABEL: @truncFPConstantRounding
+// CHECK: constant 1.444000e+25 : f32
+// CHECK: fptrunc
+func @truncFPConstantRounding() -> bf16 {
+  %cst = constant 1.444000e+25 : f32
+  %0 = fptrunc %cst : f32 to bf16
+  return %0 : bf16
+}
+
+// Test that sources wider than 'double' fold without converting through it.
+// CHECK-LABEL: @truncFPConstantF128
+// CHECK: %[[cres:.+]] = constant 1.000000e+00 : f64
+// CHECK: return %[[cres]]
+func @truncFPConstantF128() -> f64 {
+  %cst = constant 1.000000e+00 : f128
+  %0 = fptrunc %cst : f128 to f64
+  return %0 : f64
+}
+
 // CHECK-LABEL: @tripleAddAdd
 // CHECK: %[[cres:.+]] = constant 59 : index
 // CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index