diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1153,6 +1153,7 @@
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -54,6 +54,9 @@
   /// Return the bitwidth of this float type.
   unsigned getWidth();
 
+  /// Return the width of the mantissa of this type.
+  unsigned getFPMantissaWidth();
+
   /// Get or create a new FloatType with bitwidth scaled by `scale`.
   /// Return null if the scaled element type cannot be represented.
   FloatType scaleElementBitwidth(unsigned scale);
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1392,6 +1392,299 @@
   return BoolAttr::get(getContext(), val);
 }
 
+/// Rewrite an arith.cmpf of a sitofp/uitofp result against a float constant
+/// into the equivalent arith.cmpi on the original integer value, when the
+/// conversion is known not to change the outcome of the comparison.
+class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
+public:
+  using OpRewritePattern<CmpFOp>::OpRewritePattern;
+
+  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
+                                                 bool isUnsigned) {
+    using namespace arith;
+    switch (pred) {
+    case CmpFPredicate::UEQ:
+    case CmpFPredicate::OEQ:
+      return CmpIPredicate::eq;
+    case CmpFPredicate::UGT:
+    case CmpFPredicate::OGT:
+      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
+    case CmpFPredicate::UGE:
+    case CmpFPredicate::OGE:
+      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
+    case CmpFPredicate::ULT:
+    case CmpFPredicate::OLT:
+      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
+    case CmpFPredicate::ULE:
+    case CmpFPredicate::OLE:
+      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
+    case CmpFPredicate::UNE:
+    case CmpFPredicate::ONE:
+      return CmpIPredicate::ne;
+    default:
+      llvm_unreachable("Unexpected predicate!");
+    }
+  }
+
+  LogicalResult matchAndRewrite(CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    FloatAttr flt;
+    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
+      return failure();
+
+    const APFloat &rhs = flt.getValue();
+
+    // Don't attempt to fold a nan.
+    if (rhs.isNaN())
+      return failure();
+
+    // Get the width of the mantissa. We don't want to hack on conversions that
+    // might lose information from the integer, e.g. "i64 -> float"
+    FloatType floatTy = op.getRhs().getType().cast<FloatType>();
+    int mantissaWidth = floatTy.getFPMantissaWidth();
+    if (mantissaWidth <= 0)
+      return failure();
+
+    bool isUnsigned;
+    Value intVal;
+
+    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
+      isUnsigned = false;
+      intVal = si.getIn();
+    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
+      isUnsigned = true;
+      intVal = ui.getIn();
+    } else {
+      return failure();
+    }
+
+    // Check to see that the input is converted from an integer type that is
+    // small enough that preserves all bits.
+    auto intTy = intVal.getType().cast<IntegerType>();
+    auto intWidth = intTy.getWidth();
+
+    // Number of bits representing values, as opposed to the sign
+    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
+
+    // Following test does NOT adjust intWidth downwards for signed inputs,
+    // because the most negative value still requires all the mantissa bits
+    // to distinguish it from one less than that value.
+    if ((int)intWidth > mantissaWidth) {
+      // Conversion would lose accuracy. Check if loss can impact comparison.
+      int exponent = ilogb(rhs);
+      if (exponent == APFloat::IEK_Inf) {
+        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
+        if (maxExponent < (int)valueBits) {
+          // Conversion could create infinity.
+          return failure();
+        }
+      } else {
+        // Note that if rhs is zero or NaN, then Exp is negative
+        // and first condition is trivially false.
+        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
+          // Conversion could affect comparison.
+          return failure();
+        }
+      }
+    }
+
+    // Convert to equivalent cmpi predicate
+    CmpIPredicate pred;
+    switch (op.getPredicate()) {
+    case CmpFPredicate::ORD:
+      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
+      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                 /*width=*/1);
+      return success();
+    case CmpFPredicate::UNO:
+      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
+      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                 /*width=*/1);
+      return success();
+    default:
+      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
+      break;
+    }
+
+    if (!isUnsigned) {
+      // If the rhs value is > SignedMax, fold the comparison. This handles
+      // +INF and large values.
+      APFloat signedMax(rhs.getSemantics());
+      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
+                                 APFloat::rmNearestTiesToEven);
+      if (signedMax < rhs) { // smax < 13123.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
+            pred == CmpIPredicate::sle)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    } else {
+      // If the rhs value is > UnsignedMax, fold the comparison. This handles
+      // +INF and large values.
+      APFloat unsignedMax(rhs.getSemantics());
+      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
+                                   APFloat::rmNearestTiesToEven);
+      if (unsignedMax < rhs) { // umax < 13123.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
+            pred == CmpIPredicate::ule)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    }
+
+    if (!isUnsigned) {
+      // See if the rhs value is < SignedMin.
+      APFloat signedMin(rhs.getSemantics());
+      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
+                                 APFloat::rmNearestTiesToEven);
+      if (signedMin > rhs) { // smin > 12312.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
+            pred == CmpIPredicate::sge)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    } else {
+      // See if the rhs value is < UnsignedMin.
+      APFloat unsignedMin(rhs.getSemantics());
+      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
+                                   APFloat::rmNearestTiesToEven);
+      if (unsignedMin > rhs) { // umin > 12312.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
+            pred == CmpIPredicate::uge)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    }
+
+    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
+    // [0, UMAX], but it may still be fractional. See if it is fractional by
+    // casting the FP value to the integer value and back, checking for
+    // equality. Don't do this for zero, because -0.0 is not fractional.
+    bool ignored;
+    APSInt rhsInt(intWidth, isUnsigned);
+    if (APFloat::opInvalidOp ==
+        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
+      // Undefined behavior invoked - the destination type can't represent
+      // the input constant.
+      return failure();
+    }
+
+    if (!rhs.isZero()) {
+      APFloat apf(floatTy.getFloatSemantics(),
+                  APInt::getZero(floatTy.getWidth()));
+      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
+
+      bool equal = apf == rhs;
+      if (!equal) {
+        // If we had a comparison against a fractional value, we have to adjust
+        // the compare predicate and sometimes the value. rhsInt is rounded
+        // towards zero at this point.
+        switch (pred) {
+        default:
+          llvm_unreachable("Unexpected integer comparison!");
+        case CmpIPredicate::ne: // (float)int != 4.4 --> true
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+          return success();
+        case CmpIPredicate::eq: // (float)int == 4.4 --> false
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+          return success();
+        case CmpIPredicate::ule:
+          // (float)int <= 4.4 --> int <= 4
+          // (float)int <= -4.4 --> false
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                       /*width=*/1);
+            return success();
+          }
+          break;
+        case CmpIPredicate::sle:
+          // (float)int <= 4.4 --> int <= 4
+          // (float)int <= -4.4 --> int < -4
+          if (rhs.isNegative())
+            pred = CmpIPredicate::slt;
+          break;
+        case CmpIPredicate::ult:
+          // (float)int < -4.4 --> false
+          // (float)int < 4.4 --> int <= 4
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                       /*width=*/1);
+            return success();
+          }
+          pred = CmpIPredicate::ule;
+          break;
+        case CmpIPredicate::slt:
+          // (float)int < -4.4 --> int < -4
+          // (float)int < 4.4 --> int <= 4
+          if (!rhs.isNegative())
+            pred = CmpIPredicate::sle;
+          break;
+        case CmpIPredicate::ugt:
+          // (float)int > 4.4 --> int > 4
+          // (float)int > -4.4 --> true
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                       /*width=*/1);
+            return success();
+          }
+          break;
+        case CmpIPredicate::sgt:
+          // (float)int > 4.4 --> int > 4
+          // (float)int > -4.4 --> int >= -4
+          if (rhs.isNegative())
+            pred = CmpIPredicate::sge;
+          break;
+        case CmpIPredicate::uge:
+          // (float)int >= -4.4 --> true
+          // (float)int >= 4.4 --> int > 4
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                       /*width=*/1);
+            return success();
+          }
+          pred = CmpIPredicate::ugt;
+          break;
+        case CmpIPredicate::sge:
+          // (float)int >= -4.4 --> int >= -4
+          // (float)int >= 4.4 --> int > 4
+          if (!rhs.isNegative())
+            pred = CmpIPredicate::sgt;
+          break;
+        }
+      }
+    }
+
+    // Lower this FP comparison into an appropriate integer version of the
+    // comparison.
+    rewriter.replaceOpWithNewOp<CmpIOp>(
+        op, pred, intVal,
+        rewriter.create<ConstantOp>(
+            op.getLoc(), intVal.getType(),
+            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
+    return success();
+  }
+};
+
+void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.insert<CmpFIntToFPConst>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -137,6 +137,10 @@
   return FloatType();
 }
 
+unsigned FloatType::getFPMantissaWidth() {
+  return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
 //===----------------------------------------------------------------------===//
 // FunctionType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -788,3 +788,86 @@
   %res = arith.sitofp %c0 : i32 to f32
   return %res : f32
 }
+
+// -----
+
+// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll
+// When inst combining an FCMP with the LHS coming from a arith.uitofp instruction, we
+// can lower it to signed ICMP instructions.
+
+// CHECK-LABEL: @test1(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test1(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf ole, %1, %cst : f64
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi ule, %[[arg0]], %[[c0]] : i32
+  return %2 : i1
+}
+
+// CHECK-LABEL: @test2(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test2(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf olt, %1, %cst : f64
+  return %2 : i1
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi ult, %[[arg0]], %[[c0]] : i32
+}
+
+// CHECK-LABEL: @test3(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test3(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf oge, %1, %cst : f64
+  return %2 : i1
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi uge, %[[arg0]], %[[c0]] : i32
+}
+
+// CHECK-LABEL: @test4(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test4(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf ogt, %1, %cst : f64
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi ugt, %[[arg0]], %[[c0]] : i32
+  return %2 : i1
+}
+
+// CHECK-LABEL: @test5(
+func @test5(%arg0: i32) -> i1 {
+  %cst = arith.constant -4.400000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf ogt, %1, %cst : f64
+  return %2 : i1
+  // CHECK: %[[true:.+]] = arith.constant true
+  // CHECK: return %[[true]] : i1
+}
+
+// CHECK-LABEL: @test6(
+func @test6(%arg0: i32) -> i1 {
+  %cst = arith.constant -4.400000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf olt, %1, %cst : f64
+  return %2 : i1
+  // CHECK: %[[false:.+]] = arith.constant false
+  // CHECK: return %[[false]] : i1
+}
+
+// Check that optimizing unsigned >= comparisons correctly distinguishes
+// positive and negative constants.
+// CHECK-LABEL: @test7(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test7(%arg0: i32) -> i1 {
+  %cst = arith.constant 3.200000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf oge, %1, %cst : f64
+  return %2 : i1
+  // CHECK: %[[c3:.+]] = arith.constant 3 : i32
+  // CHECK: arith.cmpi ugt, %[[arg0]], %[[c3]] : i32
+}