diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -246,55 +246,29 @@ f32Ty = shapedTy.clone(f32Ty); } - Value bitcast = b.create(i32Ty, operand); - - Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter); + // See also lib/ExecutionEngine/Float16bits.cpp . + Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter); + Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter); - Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter); - Value expMask = - createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter); - Value expMax = - createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter); - - // Grab the sign bit. - Value sign = b.create(bitcast, c31); - - // Our mantissa rounding value depends on the sign bit and the last - // truncated bit. - Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter); - cManRound = b.create(cManRound, sign); - - // Grab out the mantissa and directly apply rounding. - Value man = b.create(bitcast, c23Mask); - Value manRound = b.create(man, cManRound); - - // Grab the overflow bit and shift right if we overflow. - Value roundBit = b.create(manRound, c23); - Value manNew = b.create(manRound, roundBit); - - // Grab the exponent and round using the mantissa's carry bit. - Value exp = b.create(bitcast, expMask); - Value expCarry = b.create(exp, manRound); - expCarry = b.create(expCarry, expMask); - - // If the exponent is saturated, we keep the max value. - Value expCmp = - b.create(arith::CmpIPredicate::uge, exp, expMax); - exp = b.create(expCmp, exp, expCarry); - - // If the exponent is max and we rolled over, keep the old mantissa. - Value roundBitBool = b.create(i1Ty, roundBit); - Value keepOldMan = b.create(expCmp, roundBitBool); - man = b.create(keepOldMan, man, manNew); - - // Assemble the now rounded f32 value (as an i32). - Value rounded = b.create(sign, c31); - rounded = b.create(rounded, exp); - rounded = b.create(rounded, man); + Value cBias = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); + Value qNaN = createConst(op.getLoc(), i16Ty, 0x7FC0, rewriter); + Value sNaN = createConst(op.getLoc(), i16Ty, 0xFFC0, rewriter); - Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); - Value shr = b.create(rounded, c16); - Value trunc = b.create(i16Ty, shr); + Value bitcast = b.create(i32Ty, operand); + Value isNaN = + b.create(arith::CmpFPredicate::UNO, operand, operand); + Value sign = + b.create(i1Ty, b.create(bitcast, c31)); + Value nanVal = b.create(sign, sNaN, qNaN); + + Value lsb = + b.create(b.create(bitcast, c16), c1); + Value roundingBias = b.create(cBias, lsb); + Value biased = b.create(bitcast, roundingBias); + + Value shifted = b.create(biased, c16); + Value truncTypical = b.create(i16Ty, shifted); + Value trunc = b.create(isNaN, nanVal, truncTypical); Value result = b.create(resultTy, trunc); rewriter.replaceOp(op, result); diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -225,34 +225,24 @@ } // CHECK-LABEL: @truncf_f32 - +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 // CHECK-DAG: %[[C16:.+]] = arith.constant 16 -// CHECK-DAG: %[[C32768:.+]] = arith.constant 32768 -// CHECK-DAG: %[[C2130706432:.+]] = arith.constant 2130706432 -// CHECK-DAG: %[[C2139095040:.+]] = arith.constant 2139095040 -// CHECK-DAG: %[[C8388607:.+]] = arith.constant 8388607 // CHECK-DAG: %[[C31:.+]] = arith.constant 31 -// CHECK-DAG: %[[C23:.+]] = arith.constant 23 +// CHECK-DAG: %[[C32767:.+]] = arith.constant 32767 +// CHECK-DAG: %[[C32704:.+]] = arith.constant 32704 +// CHECK-DAG: %[[C_64:.+]] = arith.constant -64 // CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0 +// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf uno, %arg0, %arg0 // CHECK-DAG: %[[SIGN:.+]] = arith.shrui %[[BITCAST:.+]], %[[C31]] -// CHECK-DAG: %[[ROUND:.+]] = arith.subi %[[C32768]], %[[SIGN]] -// CHECK-DAG: %[[MANTISSA:.+]] = arith.andi %[[BITCAST]], %[[C8388607]] -// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[MANTISSA]], %[[ROUND]] -// CHECK-DAG: %[[ROLL:.+]] = arith.shrui %[[ROUNDED]], %[[C23]] -// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[ROUNDED]], %[[ROLL]] -// CHECK-DAG: %[[EXP:.+]] = arith.andi %0, %[[C2139095040]] -// CHECK-DAG: %[[EXPROUND:.+]] = arith.addi %[[EXP]], %[[ROUNDED]] -// CHECK-DAG: %[[EXPROLL:.+]] = arith.andi %[[EXPROUND]], %[[C2139095040]] -// CHECK-DAG: %[[EXPMAX:.+]] = arith.cmpi uge, %[[EXP]], %[[C2130706432]] -// CHECK-DAG: %[[EXPNEW:.+]] = arith.select %[[EXPMAX]], %[[EXP]], %[[EXPROLL]] -// CHECK-DAG: %[[OVERFLOW_B:.+]] = arith.trunci %[[ROLL]] -// CHECK-DAG: %[[KEEP_MAN:.+]] = arith.andi %[[EXPMAX]], %[[OVERFLOW_B]] -// CHECK-DAG: %[[MANNEW:.+]] = arith.select %[[KEEP_MAN]], %[[MANTISSA]], %[[SHR]] -// CHECK-DAG: %[[NEWSIGN:.+]] = arith.shli %[[SIGN]], %[[C31]] -// CHECK-DAG: %[[WITHEXP:.+]] = arith.ori %[[NEWSIGN]], %[[EXPNEW]] -// CHECK-DAG: %[[WITHMAN:.+]] = arith.ori %[[WITHEXP]], %[[MANNEW]] -// CHECK-DAG: %[[SHIFT:.+]] = arith.shrui %[[WITHMAN]], %[[C16]] -// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHIFT]] +// CHECK-DAG: %[[SIGNBIT:.+]] = arith.trunci %[[SIGN]] +// CHECK-DAG: %[[NANVAL:.+]] = arith.select %[[SIGNBIT]], %[[C_64]], %[[C32704]] +// CHECK-DAG: %[[LSB_PART:.+]] = arith.shrui %[[BITCAST]], %[[C16]] +// CHECK-DAG: %[[LSB:.+]] = arith.andi %[[LSB_PART]], %[[C1]] +// CHECK-DAG: %[[BIAS:.+]] = arith.addi %[[C32767]], %[[LSB]] +// CHECK-DAG: %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[BIAS]] +// CHECK-DAG: %[[SHIFT:.+]] = arith.shrui %[[BIASED]], %[[C16]] +// CHECK-DAG: %[[TRUNCTYPICAL:.+]] = arith.trunci %[[SHIFT]] +// CHECK-DAG: %[[TRUNC:.+]] = arith.select %[[ISNAN]], %[[NANVAL]], %[[TRUNCTYPICAL]] // CHECK-DAG: %[[RES:.+]] = arith.bitcast %[[TRUNC]] // CHECK: return %[[RES]] diff --git a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir --- a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir +++ b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir @@ -14,6 +14,11 @@ func.func @main() { // CHECK: 1.00781 + %noRoundOneI = arith.constant 0x3f808001 : i32 + %noRoundOneF = arith.bitcast %noRoundOneI : i32 to f32 + call @trunc_bf16(%noRoundOneF): (f32) -> () + + // CHECK: 1 %roundOneI = arith.constant 0x3f808000 : i32 %roundOneF = arith.bitcast %roundOneI : i32 to f32 call @trunc_bf16(%roundOneF): (f32) -> () @@ -38,12 +43,12 @@ %neginff = arith.bitcast %neginfi : i32 to f32 call @trunc_bf16(%neginff): (f32) -> () - // CHECK-NEXT: 3.38953e+38 + // CHECK-NEXT: inf %bigi = arith.constant 0x7f7fffff : i32 %bigf = arith.bitcast %bigi : i32 to f32 call @trunc_bf16(%bigf): (f32) -> () - // CHECK-NEXT: -3.38953e+38 + // CHECK-NEXT: -inf %negbigi = arith.constant 0xff7fffff : i32 %negbigf = arith.bitcast %negbigi : i32 to f32 call @trunc_bf16(%negbigf): (f32) -> ()