diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -22,6 +22,7 @@ void populateExpandExp2FPattern(RewritePatternSet &patterns); void populateExpandPowFPattern(RewritePatternSet &patterns); void populateExpandRoundFPattern(RewritePatternSet &patterns); +void populateExpandRoundEvenPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); struct MathPolynomialApproximationOptions { diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -48,9 +48,14 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { Type opType = operand.getType(); - Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand); + Type i64Ty = b.getI64Type(); + if (auto shapedTy = dyn_cast<ShapedType>(opType)) + i64Ty = shapedTy.clone(i64Ty); + Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand); Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); - return fpFixedConvert; + // The truncation does not preserve the sign when the truncated + // value is -0. So here the sign is copied again. + return b.create<math::CopySignOp>(fpFixedConvert, operand); } /// Expands tanh op into @@ -189,23 +194,66 @@ static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter) { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); + Type opEType = getElementTypeOrSelf(opType); - // Creating constants for later use.
- Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); - Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value negHalf = createFloatConst(op->getLoc(), opType, -0.5, rewriter); + if (!opEType.isF32()) { + return rewriter.notifyMatchFailure(op, "not a round of f32."); + } - Value posCheck = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, zero); - Value incrValue = - b.create<arith::SelectOp>(op->getLoc(), posCheck, half, negHalf); - Value add = b.create<arith::AddFOp>(opType, operand, incrValue); + Type i32Ty = b.getI32Type(); + if (auto shapedTy = dyn_cast<ShapedType>(opType)) + i32Ty = shapedTy.clone(i32Ty); + + Value half = createFloatConst(loc, opType, 0.5, b); + Value c23 = createIntConst(loc, i32Ty, 23, b); + Value c127 = createIntConst(loc, i32Ty, 127, b); + Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); + Value incrValue = b.create<math::CopySignOp>(half, operand); + Value add = b.create<arith::AddFOp>(opType, operand, incrValue); Value fpFixedConvert = createTruncatedFPValue(add, b); - rewriter.replaceOp(op, fpFixedConvert); + + // There are three cases where adding 0.5 to the value and truncating by + // converting to an i64 does not result in the correct behavior: + // + // 1. Special values: +-inf and +-nan + // Casting these special values to i64 has undefined behavior. To identify + // these values, we use the fact that these values are the only float + // values with the maximum possible biased exponent. + // + // 2. Large values: 2^23 <= |x| <= INT_64_MAX + // Adding 0.5 to a float larger than or equal to 2^23 results in precision + // errors that sometimes round the value up and sometimes round the value + // down. For example: + // 8388608.0 + 0.5 = 8388608.0 + // 8388609.0 + 0.5 = 8388610.0 + // + // 3. Very large values: |x| > INT_64_MAX + // Casting to i64 a value greater than the max i64 value will overflow the + // i64 leading to wrong outputs. + // + // All three cases satisfy the property `biasedExp >= 23`. + // + // NOTE: One more case is not handled: |x| == 0x1.fffffep-2, the largest + // f32 below 0.5 (~0.49999997). The exact sum |x| + 0.5 == 1 - 2^-25 is not + // representable and rounds (ties-to-even) to 1.0, so this expansion yields + // +-1.0 instead of +-0.0 (and so does the roundeven expansion built on it). + // Its `biasedExp == -2`, so the `biasedExp >= 23` check does not catch it. + // TODO: Handle it, e.g. by adding copysign(nextafter(0.5f, 0.0f), x).
+ Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); + Value operandExp = b.create<arith::AndIOp>( + b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); + Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); + Value isSpecialValOrLargeVal = + b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23); + + Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand, + fpFixedConvert); + rewriter.replaceOp(op, result); return success(); } @@ -253,6 +301,129 @@ return success(); } +// Convert f32 `math.roundeven` into `math.round` + arith ops +static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + Value operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + if (!operandETy.isF32() || !resultETy.isF32()) { + return rewriter.notifyMatchFailure(op, "not a roundeven of f32."); + } + + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) { + i32Ty = shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + Value c1Float = createFloatConst(loc, f32Ty, 1.0, b); + Value c0 = createIntConst(loc, i32Ty, 0, b); + Value c1 = createIntConst(loc, i32Ty, 1, b); + Value cNeg1 = createIntConst(loc, i32Ty, -1, b); + Value c23 = createIntConst(loc, i32Ty, 23, b); + Value c31 = createIntConst(loc, i32Ty, 31, b); + Value c127 = createIntConst(loc, i32Ty, 127, b); + Value c2To22 = createIntConst(loc, i32Ty, 1 << 22, b); + Value c23Mask = createIntConst(loc, i32Ty, (1 << 23) - 1, b); + Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); + + Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); + Value round = b.create<math::RoundOp>(operand); + Value roundBitcast = b.create<arith::BitcastOp>(i32Ty, round); + + // Get biased exponents for operand and round(operand) + Value operandExp = b.create<arith::AndIOp>( + b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); + Value 
operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); + Value roundExp = b.create<arith::AndIOp>( + b.create<arith::ShRUIOp>(roundBitcast, c23), expMask); + Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127); + + auto safeShiftRight = [&](Value x, Value shift) -> Value { + // Clamp shift to valid range [0, 31] to avoid undefined behavior + Value clampedShift = b.create<arith::MaxSIOp>(shift, c0); + clampedShift = b.create<arith::MinSIOp>(clampedShift, c31); + return b.create<arith::ShRUIOp>(x, clampedShift); + }; + + auto maskMantissa = [&](Value mantissa, + Value mantissaMaskRightShift) -> Value { + Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); + return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask); + }; + + // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring + // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers + // with `biasedExp > 23` (numbers where there is not enough precision to store + // decimals) are always even, and they satisfy the even condition trivially + // since the mantissa without all its bits is zero. The even condition + // is also true for +-0, since they have `biasedExp = -127` and the entire + // mantissa is zero. The case of +-1 has to be handled separately. Here + // we identify these values by noting that +-1 are the only whole numbers with + // `biasedExp == 0`. + // + // The special values +-inf and +-nan also satisfy the same property that + // whole non-unit even numbers satisfy. In particular, the special values have + // `biasedExp > 23`, so they get treated as large numbers with no room for + // decimals, which are always even.
+ Value roundBiasedExpEq0 = + b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0); + Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1); + Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); + Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>( + arith::CmpIPredicate::ne, roundMaskedMantissa, c0); + roundIsNotEvenOrSpecialVal = + b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); + + // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive + // integers if the bit at index `biasedExp` starting from the left in the + // mantissa is 1 and all the bits to the right are zero. Values with + // `biasedExp >= 23` don't have decimals, so they are never halfway. The + // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, + // so these are handled separately. In particular, if `biasedExp == -1`, the + // value is halfway if the entire mantissa is zero. + Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>( + arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); + Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>( + operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); + Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); + Value operandIsHalfway = + b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa, + expectedOperandMaskedMantissa); + // Ensure `biasedExp` is in the valid range for half values. + Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>( + arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); + Value operandBiasedExpLt23 = + b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23); + operandIsHalfway = + b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23); + operandIsHalfway = + b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1); + + // Adjust rounded operand with `round(operand) - sign(operand)` to correct the + // case where `round` rounded in the opposite direction of `roundeven`.
+ Value sign = b.create<math::CopySignOp>(c1Float, operand); + Value roundShifted = b.create<arith::SubFOp>(round, sign); + // If the rounded value is even or a special value, we default to the behavior + // of `math.round`. + Value needsShift = + b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway); + Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round); + // The `x - sign` adjustment does not preserve the sign when we are adjusting + // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is + // rounded to -0.0. + result = b.create<math::CopySignOp>(result, operand); + rewriter.replaceOp(op, result); + return success(); +} + void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { patterns.add(convertCtlzOp); } @@ -288,3 +459,7 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { patterns.add(convertFloorOp); } + +void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { + patterns.add(convertRoundEvenOp); +} diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -141,9 +141,10 @@ // CHECK-DAG: [[CST_0:%.+]] = arith.constant -1.000 // CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]] // CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]] + // CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]] // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf olt, [[ARG0]], [[CST]] // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]] - // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]] + // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]] // CHECK-NEXT: return [[ADDF]] %ret = math.floor %a : f64 return %ret : f64 @@ -158,9 +159,10 @@ // CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000 // CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]] // CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]] - // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[CVTF]] + // CHECK-NEXT: [[COPYSIGN:%.+]] = 
math.copysign [[CVTF]], [[ARG0]] + // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]] // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]] - // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]] + // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]] // CHECK-NEXT: return [[ADDF]] %ret = math.ceil %a : f64 return %ret : f64 @@ -193,19 +195,26 @@ // ----- // CHECK-LABEL: func @roundf_func -// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64 -func.func @roundf_func(%a: f64) -> f64 { - // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000 - // CHECK-DAG: [[CST_0:%.+]] = arith.constant 5.000000e-01 - // CHECK-DAG: [[CST_1:%.+]] = arith.constant -5.000000e-01 - // CHECK-DAG: [[COMP:%.+]] = arith.cmpf oge, [[ARG0]], [[CST]] - // CHECK-DAG: [[SEL:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST_1]] - // CHECK-DAG: [[ADDF:%.+]] = arith.addf [[ARG0]], [[SEL]] - // CHECK-DAG: [[CVTI:%.+]] = arith.fptosi [[ADDF]] - // CHECK-DAG: [[CVTF:%.+]] = arith.sitofp [[CVTI]] - // CHECK: return [[CVTF]] - %ret = math.round %a : f64 - return %ret : f64 +// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 +func.func @roundf_func(%a: f32) -> f32 { + // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 + // CHECK-DAG: %[[C23:.*]] = arith.constant 23 + // CHECK-DAG: %[[C127:.*]] = arith.constant 127 + // CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 255 + // CHECK-DAG: %[[SHIFT:.*]] = math.copysign %[[HALF]], %[[ARG0]] + // CHECK-DAG: %[[ARG_SHIFTED:.*]] = arith.addf %[[ARG0]], %[[SHIFT]] + // CHECK-DAG: %[[FIXED_CONVERT:.*]] = arith.fptosi %[[ARG_SHIFTED]] + // CHECK-DAG: %[[FP_FIXED_CONVERT_0:.*]] = arith.sitofp %[[FIXED_CONVERT]] + // CHECK-DAG: %[[FP_FIXED_CONVERT_1:.*]] = math.copysign %[[FP_FIXED_CONVERT_0]], %[[ARG_SHIFTED]] + // CHECK-DAG: %[[ARG_BITCAST:.*]] = arith.bitcast %[[ARG0]] : f32 to i32 + // CHECK-DAG: %[[ARG_BITCAST_SHIFTED:.*]] = arith.shrui %[[ARG_BITCAST]], %[[C23]] + // CHECK-DAG: %[[ARG_EXP:.*]] = arith.andi %[[ARG_BITCAST_SHIFTED]], 
%[[EXP_MASK]] + // CHECK-DAG: %[[ARG_BIASED_EXP:.*]] = arith.subi %[[ARG_EXP]], %[[C127]] + // CHECK-DAG: %[[IS_SPECIAL_VAL:.*]] = arith.cmpi sge, %[[ARG_BIASED_EXP]], %[[C23]] + // CHECK-DAG: %[[RESULT:.*]] = arith.select %[[IS_SPECIAL_VAL]], %[[ARG0]], %[[FP_FIXED_CONVERT_1]] + // CHECK: return %[[RESULT]] + %ret = math.round %a : f32 + return %ret : f32 } // ----- @@ -220,3 +229,105 @@ %ret = math.powf %a, %b : f64 return %ret : f64 } + +// ----- + +// CHECK-LABEL: func.func @roundeven +func.func @roundeven(%arg: f32) -> f32 { + %res = math.roundeven %arg : f32 + return %res : f32 +} + +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { +// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C_NEG_1:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[C_1_FLOAT:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[C_23:.*]] = arith.constant 23 : i32 +// CHECK-DAG: %[[C_31:.*]] = arith.constant 31 : i32 +// CHECK-DAG: %[[C_127:.*]] = arith.constant 127 : i32 +// CHECK-DAG: %[[C_4194304:.*]] = arith.constant 4194304 : i32 +// CHECK-DAG: %[[C_8388607:.*]] = arith.constant 8388607 : i32 +// CHECK-DAG: %[[EXP_MASK:.*]] = arith.constant 255 : i32 +// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 + +// CHECK: %[[OPERAND_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f32 to i32 + +// Calculate `math.round(operand)` using expansion pattern for `round` and +// bitcast result to i32 +// CHECK: %[[SHIFT:.*]] = math.copysign %[[HALF]], %[[VAL_0]] +// CHECK: %[[ARG_SHIFTED:.*]] = arith.addf %[[VAL_0]], %[[SHIFT]] +// CHECK: %[[FIXED_CONVERT:.*]] = arith.fptosi %[[ARG_SHIFTED]] +// CHECK: %[[FP_FIXED_CONVERT_0:.*]] = arith.sitofp %[[FIXED_CONVERT]] +// CHECK: %[[FP_FIXED_CONVERT_1:.*]] = math.copysign %[[FP_FIXED_CONVERT_0]], %[[ARG_SHIFTED]] +// CHECK: %[[ARG_BITCAST:.*]] = arith.bitcast %[[VAL_0]] : f32 to i32 +// CHECK: %[[ARG_BITCAST_SHIFTED:.*]] = arith.shrui %[[ARG_BITCAST]], %[[C_23]] +// CHECK: 
%[[ARG_EXP:.*]] = arith.andi %[[ARG_BITCAST_SHIFTED]], %[[EXP_MASK]] +// CHECK: %[[ARG_BIASED_EXP:.*]] = arith.subi %[[ARG_EXP]], %[[C_127]] +// CHECK: %[[IS_SPECIAL_VAL:.*]] = arith.cmpi sge, %[[ARG_BIASED_EXP]], %[[C_23]] +// CHECK: %[[ROUND:.*]] = arith.select %[[IS_SPECIAL_VAL]], %[[VAL_0]], %[[FP_FIXED_CONVERT_1]] +// CHECK: %[[ROUND_BITCAST:.*]] = arith.bitcast %[[ROUND]] : f32 to i32 + +// Get biased exponents of `round` and `operand` +// CHECK: %[[SHIFTED_OPERAND_BITCAST:.*]] = arith.shrui %[[OPERAND_BITCAST]], %[[C_23]] : i32 +// CHECK: %[[OPERAND_EXP:.*]] = arith.andi %[[SHIFTED_OPERAND_BITCAST]], %[[EXP_MASK]] : i32 +// CHECK: %[[OPERAND_BIASED_EXP:.*]] = arith.subi %[[OPERAND_EXP]], %[[C_127]] : i32 +// CHECK: %[[SHIFTED_ROUND_BITCAST:.*]] = arith.shrui %[[ROUND_BITCAST]], %[[C_23]] : i32 +// CHECK: %[[ROUND_EXP:.*]] = arith.andi %[[SHIFTED_ROUND_BITCAST]], %[[EXP_MASK]] : i32 +// CHECK: %[[ROUND_BIASED_EXP:.*]] = arith.subi %[[ROUND_EXP]], %[[C_127]] : i32 + +// Determine if `ROUND_BITCAST` is an even whole number or a special value +// +-inf, +-nan.
+// Mask mantissa of `ROUND_BITCAST` with a mask shifted to the right by +// `ROUND_BIASED_EXP - 1` +// CHECK-DAG: %[[ROUND_BIASED_EXP_MINUS_1:.*]] = arith.subi %[[ROUND_BIASED_EXP]], %[[C_1]] : i32 +// CHECK-DAG: %[[CLAMPED_SHIFT_0:.*]] = arith.maxsi %[[ROUND_BIASED_EXP_MINUS_1]], %[[C_0]] : i32 +// CHECK-DAG: %[[CLAMPED_SHIFT_1:.*]] = arith.minsi %[[CLAMPED_SHIFT_0]], %[[C_31]] : i32 +// CHECK-DAG: %[[SHIFTED_MANTISSA_MASK_0:.*]] = arith.shrui %[[C_8388607]], %[[CLAMPED_SHIFT_1]] : i32 +// CHECK-DAG: %[[ROUND_MASKED_MANTISSA:.*]] = arith.andi %[[ROUND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_0]] : i32 + +// `ROUND_BITCAST` is neither an even whole number nor a special value if the +// masked mantissa is != 0 or `ROUND_BIASED_EXP == 0` +// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0:.*]] = arith.cmpi ne, %[[ROUND_MASKED_MANTISSA]], %[[C_0]] : i32 +// CHECK-DAG: %[[ROUND_BIASED_EXP_EQ_0:.*]] = arith.cmpi eq, %[[ROUND_BIASED_EXP]], %[[C_0]] : i32 +// CHECK-DAG: %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1:.*]] = arith.ori %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_0]], %[[ROUND_BIASED_EXP_EQ_0]] : i1 + +// Determine if operand is halfway between two integer values +// CHECK: %[[OPERAND_BIASED_EXP_EQ_NEG_1:.*]] = arith.cmpi eq, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i32 +// CHECK: %[[CLAMPED_SHIFT_2:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i32 +// CHECK: %[[CLAMPED_SHIFT_3:.*]] = arith.minsi %[[CLAMPED_SHIFT_2]], %[[C_31]] : i32 +// CHECK: %[[SHIFTED_2_TO_22:.*]] = arith.shrui %[[C_4194304]], %[[CLAMPED_SHIFT_3]] : i32 + +// A value with `0 <= BIASED_EXP < 23` is halfway between two consecutive +// integers if the bit at index `BIASED_EXP` starting from the left in the +// mantissa is 1 and all the bits to the right are zero. For the case where +// `BIASED_EXP == -1`, the expected mantissa is all zeros.
+// CHECK: %[[EXPECTED_OPERAND_MASKED_MANTISSA:.*]] = arith.select %[[OPERAND_BIASED_EXP_EQ_NEG_1]], %[[C_0]], %[[SHIFTED_2_TO_22]] : i32 + +// Mask mantissa of `OPERAND_BITCAST` with a mask shifted to the right by +// `OPERAND_BIASED_EXP` +// CHECK: %[[CLAMPED_SHIFT_4:.*]] = arith.maxsi %[[OPERAND_BIASED_EXP]], %[[C_0]] : i32 +// CHECK: %[[CLAMPED_SHIFT_5:.*]] = arith.minsi %[[CLAMPED_SHIFT_4]], %[[C_31]] : i32 +// CHECK: %[[SHIFTED_MANTISSA_MASK_1:.*]] = arith.shrui %[[C_8388607]], %[[CLAMPED_SHIFT_5]] : i32 +// CHECK: %[[OPERAND_MASKED_MANTISSA:.*]] = arith.andi %[[OPERAND_BITCAST]], %[[SHIFTED_MANTISSA_MASK_1]] : i32 + +// The operand is halfway between two integers if the masked mantissa is equal +// to the expected mantissa and the biased exponent is in the range +// [-1, 23). +// CHECK-DAG: %[[OPERAND_BIASED_EXP_GE_NEG_1:.*]] = arith.cmpi sge, %[[OPERAND_BIASED_EXP]], %[[C_NEG_1]] : i32 +// CHECK-DAG: %[[OPERAND_BIASED_EXP_LT_23:.*]] = arith.cmpi slt, %[[OPERAND_BIASED_EXP]], %[[C_23]] : i32 +// CHECK-DAG: %[[OPERAND_IS_HALFWAY_0:.*]] = arith.cmpi eq, %[[OPERAND_MASKED_MANTISSA]], %[[EXPECTED_OPERAND_MASKED_MANTISSA]] : i32 +// CHECK-DAG: %[[OPERAND_IS_HALFWAY_1:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_0]], %[[OPERAND_BIASED_EXP_LT_23]] : i1 +// CHECK-DAG: %[[OPERAND_IS_HALFWAY_2:.*]] = arith.andi %[[OPERAND_IS_HALFWAY_1]], %[[OPERAND_BIASED_EXP_GE_NEG_1]] : i1 + +// Adjust rounded operand with `round(operand) - sign(operand)` to correct the +// case where `round` rounded in the opposite direction of `roundeven`.
+// CHECK: %[[SIGN:.*]] = math.copysign %[[C_1_FLOAT]], %[[VAL_0]] : f32 +// CHECK: %[[ROUND_SHIFTED:.*]] = arith.subf %[[ROUND]], %[[SIGN]] : f32 +// CHECK: %[[NEEDS_SHIFT:.*]] = arith.andi %[[ROUND_IS_NOT_EVEN_OR_SPECIAL_1]], %[[OPERAND_IS_HALFWAY_2]] : i1 +// CHECK: %[[RESULT:.*]] = arith.select %[[NEEDS_SHIFT]], %[[ROUND_SHIFTED]], %[[ROUND]] : f32 + +// The `x - sign` adjustment does not preserve the sign when we are adjusting the value -1 to -0. +// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f32 + +// CHECK: return %[[COPYSIGN]] : f32 diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -45,6 +45,7 @@ populateExpandCeilFPattern(patterns); populateExpandPowFPattern(patterns); populateExpandRoundFPattern(patterns); + populateExpandRoundEvenPattern(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir --- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir @@ -19,37 +19,37 @@ %a = arith.constant 1.0 : f64 call @func_exp2f(%a) : (f64) -> () - // CHECK: 4 + // CHECK-NEXT: 4 %b = arith.constant 2.0 : f64 call @func_exp2f(%b) : (f64) -> () - // CHECK: 5.65685 + // CHECK-NEXT: 5.65685 %c = arith.constant 2.5 : f64 call @func_exp2f(%c) : (f64) -> () - // CHECK: 0.29730 + // CHECK-NEXT: 0.29730 %d = arith.constant -1.75 : f64 call @func_exp2f(%d) : (f64) -> () - // CHECK: 1.09581 + // CHECK-NEXT: 1.09581 %e = arith.constant 0.132 : f64 call @func_exp2f(%e) : (f64) -> () - // CHECK: inf + // CHECK-NEXT: inf %f1 = arith.constant 0.00 : f64 %f2 = arith.constant 1.00 : f64 %f = arith.divf %f2, %f1 : f64 call @func_exp2f(%f) : (f64) -> () - // CHECK: inf + // CHECK-NEXT: inf 
%g = arith.constant 5038939.0 : f64 call @func_exp2f(%g) : (f64) -> () - // CHECK: 0 + // CHECK-NEXT: 0 %neg_inf = arith.constant 0xff80000000000000 : f64 call @func_exp2f(%neg_inf) : (f64) -> () - // CHECK: inf + // CHECK-NEXT: inf %i = arith.constant 0x7fc0000000000000 : f64 call @func_exp2f(%i) : (f64) -> () return @@ -64,39 +64,119 @@ return } +func.func @func_roundf$bitcast_result_to_int(%a : f32) { + %b = math.round %a : f32 + %c = arith.bitcast %b : f32 to i32 + vector.print %c : i32 + return +} + +func.func @func_roundf$vector(%a : vector<1xf32>) { + %b = math.round %a : vector<1xf32> + vector.print %b : vector<1xf32> + return +} + func.func @roundf() { - // CHECK: 4 + // CHECK-NEXT: 4 %a = arith.constant 3.8 : f32 call @func_roundf(%a) : (f32) -> () - // CHECK: -4 + // CHECK-NEXT: -4 %b = arith.constant -3.8 : f32 call @func_roundf(%b) : (f32) -> () - // CHECK: 0 - %c = arith.constant 0.0 : f32 + // CHECK-NEXT: -4 + %c = arith.constant -4.2 : f32 call @func_roundf(%c) : (f32) -> () - // CHECK: -4 - %d = arith.constant -4.2 : f32 + // CHECK-NEXT: -495 + %d = arith.constant -495.0 : f32 call @func_roundf(%d) : (f32) -> () - // CHECK: -495 - %e = arith.constant -495.0 : f32 + // CHECK-NEXT: 495 + %e = arith.constant 495.0 : f32 call @func_roundf(%e) : (f32) -> () - // CHECK: 495 - %f = arith.constant 495.0 : f32 + // CHECK-NEXT: 9 + %f = arith.constant 8.5 : f32 call @func_roundf(%f) : (f32) -> () - // CHECK: 9 - %g = arith.constant 8.5 : f32 + // CHECK-NEXT: -9 + %g = arith.constant -8.5 : f32 call @func_roundf(%g) : (f32) -> () - // CHECK: -9 - %h = arith.constant -8.5 : f32 + // CHECK-NEXT: -0 + %h = arith.constant -0.4 : f32 call @func_roundf(%h) : (f32) -> () + // Special values: 0, -0, inf, -inf, nan, -nan + %cNeg0 = arith.constant -0.0 : f32 + %c0 = arith.constant 0.0 : f32 + %cInfInt = arith.constant 0x7f800000 : i32 + %cInf = arith.bitcast %cInfInt : i32 to f32 + %cNegInfInt = arith.constant 0xff800000 : i32 + %cNegInf = arith.bitcast %cNegInfInt : 
i32 to f32 + %cNanInt = arith.constant 0x7fc00000 : i32 + %cNan = arith.bitcast %cNanInt : i32 to f32 + %cNegNanInt = arith.constant 0xffc00000 : i32 + %cNegNan = arith.bitcast %cNegNanInt : i32 to f32 + + // CHECK-NEXT: -0 + call @func_roundf(%cNeg0) : (f32) -> () + // CHECK-NEXT: 0 + call @func_roundf(%c0) : (f32) -> () + // CHECK-NEXT: inf + call @func_roundf(%cInf) : (f32) -> () + // CHECK-NEXT: -inf + call @func_roundf(%cNegInf) : (f32) -> () + // Per IEEE 754-2008, sign is not required when printing a negative NaN, so + // print as an int to ensure input NaN is left unchanged. + // CHECK-NEXT: 2143289344 + // CHECK-NEXT: 2143289344 + call @func_roundf$bitcast_result_to_int(%cNan) : (f32) -> () + vector.print %cNanInt : i32 + // CHECK-NEXT: -4194304 + // CHECK-NEXT: -4194304 + call @func_roundf$bitcast_result_to_int(%cNegNan) : (f32) -> () + vector.print %cNegNanInt : i32 + + // Very large values (greater than INT_64_MAX) + %c2To100 = arith.constant 1.268e30 : f32 // ~2^100 + // CHECK-NEXT: 1.268e+30 + call @func_roundf(%c2To100) : (f32) -> () + + // Values above and below 2^23 = 8388608 + %c8388606_5 = arith.constant 8388606.5 : f32 + %c8388607 = arith.constant 8388607.0 : f32 + %c8388607_5 = arith.constant 8388607.5 : f32 + %c8388608 = arith.constant 8388608.0 : f32 + %c8388609 = arith.constant 8388609.0 : f32 + + // Bitcast result to int to avoid printing in scientific notation, + // which does not display all significant digits.
+ + // CHECK-NEXT: 1258291198 + // hex: 0x4AFFFFFE + call @func_roundf$bitcast_result_to_int(%c8388606_5) : (f32) -> () + // CHECK-NEXT: 1258291198 + // hex: 0x4AFFFFFE + call @func_roundf$bitcast_result_to_int(%c8388607) : (f32) -> () + // CHECK-NEXT: 1258291200 + // hex: 0x4B000000 + call @func_roundf$bitcast_result_to_int(%c8388607_5) : (f32) -> () + // CHECK-NEXT: 1258291200 + // hex: 0x4B000000 + call @func_roundf$bitcast_result_to_int(%c8388608) : (f32) -> () + // CHECK-NEXT: 1258291201 + // hex: 0x4B000001 + call @func_roundf$bitcast_result_to_int(%c8388609) : (f32) -> () + + // Check that vector type works + %cVec = arith.constant dense<[0.5]> : vector<1xf32> + // CHECK-NEXT: ( 1 ) + call @func_roundf$vector(%cVec) : (vector<1xf32>) -> () + return } @@ -110,52 +190,237 @@ } func.func @powf() { - // CHECK: 16 + // CHECK-NEXT: 16 %a = arith.constant 4.0 : f64 %a_p = arith.constant 2.0 : f64 call @func_powff64(%a, %a_p) : (f64, f64) -> () - // CHECK: nan + // CHECK-NEXT: nan %b = arith.constant -3.0 : f64 %b_p = arith.constant 3.0 : f64 call @func_powff64(%b, %b_p) : (f64, f64) -> () - // CHECK: 2.343 + // CHECK-NEXT: 2.343 %c = arith.constant 2.343 : f64 %c_p = arith.constant 1.000 : f64 call @func_powff64(%c, %c_p) : (f64, f64) -> () - // CHECK: 0.176171 + // CHECK-NEXT: 0.176171 %d = arith.constant 4.25 : f64 %d_p = arith.constant -1.2 : f64 call @func_powff64(%d, %d_p) : (f64, f64) -> () - // CHECK: 1 + // CHECK-NEXT: 1 %e = arith.constant 4.385 : f64 %e_p = arith.constant 0.00 : f64 call @func_powff64(%e, %e_p) : (f64, f64) -> () - // CHECK: 6.62637 + // CHECK-NEXT: 6.62637 %f = arith.constant 4.835 : f64 %f_p = arith.constant 1.2 : f64 call @func_powff64(%f, %f_p) : (f64, f64) -> () - // CHECK: nan + // CHECK-NEXT: nan %g = arith.constant 0xff80000000000000 : f64 call @func_powff64(%g, %g) : (f64, f64) -> () - // CHECK: nan + // CHECK-NEXT: nan %h = arith.constant 0x7fffffffffffffff : f64 call @func_powff64(%h, %h) : (f64, f64) -> () - // CHECK: nan + // 
CHECK-NEXT: nan %i = arith.constant 1.0 : f64 call @func_powff64(%i, %h) : (f64, f64) -> () - // CHECK: inf + // CHECK-NEXT: inf %j = arith.constant 29385.0 : f64 %j_p = arith.constant 23598.0 : f64 - call @func_powff64(%j, %j_p) : (f64, f64) -> () + call @func_powff64(%j, %j_p) : (f64, f64) -> () + return +} + +// -------------------------------------------------------------------------- // +// roundeven. +// -------------------------------------------------------------------------- // + +func.func @func_roundeven(%a : f32) { + %b = math.roundeven %a : f32 + vector.print %b : f32 + return +} + +func.func @func_roundeven$bitcast_result_to_int(%a : f32) { + %b = math.roundeven %a : f32 + %c = arith.bitcast %b : f32 to i32 + vector.print %c : i32 + return +} + +func.func @func_roundeven$vector(%a : vector<1xf32>) { + %b = math.roundeven %a : vector<1xf32> + vector.print %b : vector<1xf32> + return +} + +func.func @roundeven() { + %c0_25 = arith.constant 0.25 : f32 + %c0_5 = arith.constant 0.5 : f32 + %c0_75 = arith.constant 0.75 : f32 + %c1 = arith.constant 1.0 : f32 + %c1_25 = arith.constant 1.25 : f32 + %c1_5 = arith.constant 1.5 : f32 + %c1_75 = arith.constant 1.75 : f32 + %c2 = arith.constant 2.0 : f32 + %c2_25 = arith.constant 2.25 : f32 + %c2_5 = arith.constant 2.5 : f32 + %c2_75 = arith.constant 2.75 : f32 + %c3 = arith.constant 3.0 : f32 + %c3_25 = arith.constant 3.25 : f32 + %c3_5 = arith.constant 3.5 : f32 + %c3_75 = arith.constant 3.75 : f32 + + %cNeg0_25 = arith.constant -0.25 : f32 + %cNeg0_5 = arith.constant -0.5 : f32 + %cNeg0_75 = arith.constant -0.75 : f32 + %cNeg1 = arith.constant -1.0 : f32 + %cNeg1_25 = arith.constant -1.25 : f32 + %cNeg1_5 = arith.constant -1.5 : f32 + %cNeg1_75 = arith.constant -1.75 : f32 + %cNeg2 = arith.constant -2.0 : f32 + %cNeg2_25 = arith.constant -2.25 : f32 + %cNeg2_5 = arith.constant -2.5 : f32 + %cNeg2_75 = arith.constant -2.75 : f32 + %cNeg3 = arith.constant -3.0 : f32 + %cNeg3_25 = arith.constant -3.25 : f32 + 
%cNeg3_5 = arith.constant -3.5 : f32 + %cNeg3_75 = arith.constant -3.75 : f32 + + // CHECK-NEXT: 0 + call @func_roundeven(%c0_25) : (f32) -> () + // CHECK-NEXT: 0 + call @func_roundeven(%c0_5) : (f32) -> () + // CHECK-NEXT: 1 + call @func_roundeven(%c0_75) : (f32) -> () + // CHECK-NEXT: 1 + call @func_roundeven(%c1) : (f32) -> () + // CHECK-NEXT: 1 + call @func_roundeven(%c1_25) : (f32) -> () + // CHECK-NEXT: 2 + call @func_roundeven(%c1_5) : (f32) -> () + // CHECK-NEXT: 2 + call @func_roundeven(%c1_75) : (f32) -> () + // CHECK-NEXT: 2 + call @func_roundeven(%c2) : (f32) -> () + // CHECK-NEXT: 2 + call @func_roundeven(%c2_25) : (f32) -> () + // CHECK-NEXT: 2 + call @func_roundeven(%c2_5) : (f32) -> () + // CHECK-NEXT: 3 + call @func_roundeven(%c2_75) : (f32) -> () + // CHECK-NEXT: 3 + call @func_roundeven(%c3) : (f32) -> () + // CHECK-NEXT: 3 + call @func_roundeven(%c3_25) : (f32) -> () + // CHECK-NEXT: 4 + call @func_roundeven(%c3_5) : (f32) -> () + // CHECK-NEXT: 4 + call @func_roundeven(%c3_75) : (f32) -> () + + // CHECK-NEXT: -0 + call @func_roundeven(%cNeg0_25) : (f32) -> () + // CHECK-NEXT: -0 + call @func_roundeven(%cNeg0_5) : (f32) -> () + // CHECK-NEXT: -1 + call @func_roundeven(%cNeg0_75) : (f32) -> () + // CHECK-NEXT: -1 + call @func_roundeven(%cNeg1) : (f32) -> () + // CHECK-NEXT: -1 + call @func_roundeven(%cNeg1_25) : (f32) -> () + // CHECK-NEXT: -2 + call @func_roundeven(%cNeg1_5) : (f32) -> () + // CHECK-NEXT: -2 + call @func_roundeven(%cNeg1_75) : (f32) -> () + // CHECK-NEXT: -2 + call @func_roundeven(%cNeg2) : (f32) -> () + // CHECK-NEXT: -2 + call @func_roundeven(%cNeg2_25) : (f32) -> () + // CHECK-NEXT: -2 + call @func_roundeven(%cNeg2_5) : (f32) -> () + // CHECK-NEXT: -3 + call @func_roundeven(%cNeg2_75) : (f32) -> () + // CHECK-NEXT: -3 + call @func_roundeven(%cNeg3) : (f32) -> () + // CHECK-NEXT: -3 + call @func_roundeven(%cNeg3_25) : (f32) -> () + // CHECK-NEXT: -4 + call @func_roundeven(%cNeg3_5) : (f32) -> () + // CHECK-NEXT: -4 + call 
@func_roundeven(%cNeg3_75) : (f32) -> () + + + // Special values: 0, -0, inf, -inf, nan, -nan + %cNeg0 = arith.constant -0.0 : f32 + %c0 = arith.constant 0.0 : f32 + %cInfInt = arith.constant 0x7f800000 : i32 + %cInf = arith.bitcast %cInfInt : i32 to f32 + %cNegInfInt = arith.constant 0xff800000 : i32 + %cNegInf = arith.bitcast %cNegInfInt : i32 to f32 + %cNanInt = arith.constant 0x7fc00000 : i32 + %cNan = arith.bitcast %cNanInt : i32 to f32 + %cNegNanInt = arith.constant 0xffc00000 : i32 + %cNegNan = arith.bitcast %cNegNanInt : i32 to f32 + + // CHECK-NEXT: -0 + call @func_roundeven(%cNeg0) : (f32) -> () + // CHECK-NEXT: 0 + call @func_roundeven(%c0) : (f32) -> () + // CHECK-NEXT: inf + call @func_roundeven(%cInf) : (f32) -> () + // CHECK-NEXT: -inf + call @func_roundeven(%cNegInf) : (f32) -> () + // Per IEEE 754-2008, sign is not required when printing a negative NaN, so + // print as an int to ensure input NaN is left unchanged. + // CHECK-NEXT: 2143289344 + // CHECK-NEXT: 2143289344 + call @func_roundeven$bitcast_result_to_int(%cNan) : (f32) -> () + vector.print %cNanInt : i32 + // CHECK-NEXT: -4194304 + // CHECK-NEXT: -4194304 + call @func_roundeven$bitcast_result_to_int(%cNegNan) : (f32) -> () + vector.print %cNegNanInt : i32 + + + // Values above and below 2^23 = 8388608 + %c8388606_5 = arith.constant 8388606.5 : f32 + %c8388607 = arith.constant 8388607.0 : f32 + %c8388607_5 = arith.constant 8388607.5 : f32 + %c8388608 = arith.constant 8388608.0 : f32 + %c8388609 = arith.constant 8388609.0 : f32 + + // Bitcast result to int to avoid printing in scientific notation, + // which does not display all significant digits.
+ + // CHECK-NEXT: 1258291196 + // hex: 0x4AFFFFFC + call @func_roundeven$bitcast_result_to_int(%c8388606_5) : (f32) -> () + // CHECK-NEXT: 1258291198 + // hex: 0x4AFFFFFE + call @func_roundeven$bitcast_result_to_int(%c8388607) : (f32) -> () + // CHECK-NEXT: 1258291200 + // hex: 0x4B000000 + call @func_roundeven$bitcast_result_to_int(%c8388607_5) : (f32) -> () + // CHECK-NEXT: 1258291200 + // hex: 0x4B000000 + call @func_roundeven$bitcast_result_to_int(%c8388608) : (f32) -> () + // CHECK-NEXT: 1258291201 + // hex: 0x4B000001 + call @func_roundeven$bitcast_result_to_int(%c8388609) : (f32) -> () + + + // Check that vector type works + %cVec = arith.constant dense<[0.5]> : vector<1xf32> + // CHECK-NEXT: ( 0 ) + call @func_roundeven$vector(%cVec) : (vector<1xf32>) -> () return } @@ -163,5 +428,6 @@ call @exp2f() : () -> () call @roundf() : () -> () call @powf() : () -> () + call @roundeven() : () -> () return }