diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1136,6 +1136,7 @@
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 #endif // ARITHMETIC_OPS
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -47,6 +47,10 @@
   /// Return the bitwidth of this float type.
   unsigned getWidth();
 
+  /// Return the width of the mantissa of this type, i.e. the precision of its
+  /// significand in bits (e.g. 24 for f32, 53 for f64).
+  unsigned getFPMantissaWidth();
+
   /// Get or create a new FloatType with bitwidth scaled by `scale`.
   /// Return null if the scaled element type cannot be represented.
   FloatType scaleElementBitwidth(unsigned scale);
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1306,6 +1306,264 @@
   return BoolAttr::get(getContext(), val);
 }
 
+/// Canonicalize a floating-point comparison between an integer-to-float
+/// conversion (`arith.sitofp` / `arith.uitofp`) and a non-NaN constant into
+/// an integer comparison on the unconverted value, or into an i1 constant
+/// when the outcome does not depend on that value.
+///
+/// See: https://llvm.org/doxygen/InstCombineCompares_8cpp_source.html#l06124
+class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
+public:
+  using OpRewritePattern<CmpFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    FloatAttr flt;
+    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
+      return failure();
+
+    const APFloat &RHS = flt.getValue();
+
+    // Comparisons against a NaN constant are left untouched.
+    if (RHS.isNaN())
+      return failure();
+
+    // Get the width of the mantissa. We don't want to hack on conversions that
+    // might lose information from the integer, e.g. "i64 -> float"
+    FloatType DestTy = op.getRhs().getType().cast<FloatType>();
+    int MantissaWidth = DestTy.getFPMantissaWidth();
+    if (MantissaWidth <= 0)
+      return failure();
+
+    bool LHSUnsigned;
+    Value intVal;
+
+    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
+      LHSUnsigned = false;
+      intVal = si.getIn();
+    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
+      LHSUnsigned = true;
+      intVal = ui.getIn();
+    } else {
+      return failure();
+    }
+
+    auto IntTy = intVal.getType().cast<IntegerType>();
+    unsigned IntWidth = IntTy.getWidth();
+
+    // Replaces `op` with the i1 constant `value` and reports success.
+    auto foldToBool = [&](bool value) {
+      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, value, 1);
+      return success();
+    };
+
+    CmpIPredicate Pred;
+    switch (op.getPredicate()) {
+    case CmpFPredicate::UEQ:
+    case CmpFPredicate::OEQ:
+      Pred = CmpIPredicate::eq;
+      break;
+    case CmpFPredicate::UGT:
+    case CmpFPredicate::OGT:
+      Pred = LHSUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
+      break;
+    case CmpFPredicate::UGE:
+    case CmpFPredicate::OGE:
+      Pred = LHSUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
+      break;
+    case CmpFPredicate::ULT:
+    case CmpFPredicate::OLT:
+      Pred = LHSUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
+      break;
+    case CmpFPredicate::ULE:
+    case CmpFPredicate::OLE:
+      Pred = LHSUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
+      break;
+    case CmpFPredicate::UNE:
+    case CmpFPredicate::ONE:
+      Pred = CmpIPredicate::ne;
+      break;
+    case CmpFPredicate::AlwaysTrue:
+    case CmpFPredicate::ORD:
+      // Neither operand can be NaN here, so "ordered" always holds.
+      return foldToBool(true);
+    case CmpFPredicate::AlwaysFalse:
+    case CmpFPredicate::UNO:
+      // Neither operand can be NaN here, so "unordered" never holds.
+      return foldToBool(false);
+    }
+
+    // Check to see that the input is converted from an integer type that is
+    // small enough that preserves all bits. TODO: check here for "known" sign
+    // bits. This would allow us to handle (fptosi (x >>s 62) to float) if x is
+    // i64 f.e.
+
+    // Number of bits that encode the magnitude, i.e. excluding the sign bit
+    // of a signed input.
+    int ValueBits = static_cast<int>(IntWidth) - (LHSUnsigned ? 0 : 1);
+
+    // Following test does NOT adjust IntWidth downwards for signed inputs,
+    // because the most negative value still requires all the mantissa bits
+    // to distinguish it from one less than that value.
+    if (static_cast<int>(IntWidth) > MantissaWidth) {
+      // Conversion would lose accuracy. Check if loss can impact comparison.
+      int Exp = ilogb(RHS);
+      if (Exp == APFloat::IEK_Inf) {
+        int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics()));
+        if (MaxExponent < ValueBits) {
+          // Conversion could create infinity.
+          return failure();
+        }
+      } else {
+        // Note that if RHS is zero or NaN, then Exp is negative
+        // and first condition is trivially false.
+        if (MantissaWidth <= Exp && Exp <= ValueBits) {
+          // Conversion could affect comparison.
+          return failure();
+        }
+      }
+    }
+
+    if (!LHSUnsigned) {
+      // If the RHS value is > SignedMax, fold the comparison. This handles
+      // +INF and large values.
+      APFloat SMax(RHS.getSemantics());
+      SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true,
+                            APFloat::rmNearestTiesToEven);
+      if (SMax < RHS) // smax < 13123.0
+        return foldToBool(Pred == CmpIPredicate::ne ||
+                          Pred == CmpIPredicate::slt ||
+                          Pred == CmpIPredicate::sle);
+    } else {
+      // If the RHS value is > UnsignedMax, fold the comparison. This handles
+      // +INF and large values.
+      APFloat UMax(RHS.getSemantics());
+      UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false,
+                            APFloat::rmNearestTiesToEven);
+      if (UMax < RHS) // umax < 13123.0
+        return foldToBool(Pred == CmpIPredicate::ne ||
+                          Pred == CmpIPredicate::ult ||
+                          Pred == CmpIPredicate::ule);
+    }
+
+    if (!LHSUnsigned) {
+      // See if the RHS value is < SignedMin.
+      APFloat SMin(RHS.getSemantics());
+      SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true,
+                            APFloat::rmNearestTiesToEven);
+      if (SMin > RHS) // smin > 12312.0
+        return foldToBool(Pred == CmpIPredicate::ne ||
+                          Pred == CmpIPredicate::sgt ||
+                          Pred == CmpIPredicate::sge);
+    } else {
+      // See if the RHS value is < UnsignedMin.
+      APFloat UMin(RHS.getSemantics());
+      UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false,
+                            APFloat::rmNearestTiesToEven);
+      if (UMin > RHS) // umin > 12312.0
+        return foldToBool(Pred == CmpIPredicate::ne ||
+                          Pred == CmpIPredicate::ugt ||
+                          Pred == CmpIPredicate::uge);
+    }
+
+    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
+    // [0, UMAX], but it may still be fractional. See if it is fractional by
+    // casting the FP value to the integer value and back, checking for
+    // equality. Don't do this for zero, because -0.0 is not fractional.
+    bool ignored;
+    APSInt RHSInt(IntWidth, LHSUnsigned);
+    if (APFloat::opInvalidOp ==
+        RHS.convertToInteger(RHSInt, APFloat::rmTowardZero, &ignored)) {
+      // Undefined behavior invoked - the destination type can't represent
+      // the input constant.
+      return failure();
+    }
+
+    if (!RHS.isZero()) {
+      APFloat apf(DestTy.getFloatSemantics(),
+                  APInt::getZero(DestTy.getWidth()));
+      apf.convertFromAPInt(RHSInt, !LHSUnsigned, APFloat::rmNearestTiesToEven);
+
+      bool Equal = apf == RHS;
+      if (!Equal) {
+        // If we had a comparison against a fractional value, we have to adjust
+        // the compare predicate and sometimes the value. RHSC is rounded
+        // towards zero at this point.
+        switch (Pred) {
+        case CmpIPredicate::ne: // (float)int != 4.4 --> true
+          return foldToBool(true);
+        case CmpIPredicate::eq: // (float)int == 4.4 --> false
+          return foldToBool(false);
+        case CmpIPredicate::ule:
+          // (float)int <= 4.4 --> int <= 4
+          // (float)int <= -4.4 --> false
+          if (RHS.isNegative())
+            return foldToBool(false);
+          break;
+        case CmpIPredicate::sle:
+          // (float)int <= 4.4 --> int <= 4
+          // (float)int <= -4.4 --> int < -4
+          if (RHS.isNegative())
+            Pred = CmpIPredicate::slt;
+          break;
+        case CmpIPredicate::ult:
+          // (float)int < -4.4 --> false
+          // (float)int < 4.4 --> int <= 4
+          if (RHS.isNegative())
+            return foldToBool(false);
+          Pred = CmpIPredicate::ule;
+          break;
+        case CmpIPredicate::slt:
+          // (float)int < -4.4 --> int < -4
+          // (float)int < 4.4 --> int <= 4
+          if (!RHS.isNegative())
+            Pred = CmpIPredicate::sle;
+          break;
+        case CmpIPredicate::ugt:
+          // (float)int > 4.4 --> int > 4
+          // (float)int > -4.4 --> true
+          if (RHS.isNegative())
+            return foldToBool(true);
+          break;
+        case CmpIPredicate::sgt:
+          // (float)int > 4.4 --> int > 4
+          // (float)int > -4.4 --> int >= -4
+          if (RHS.isNegative())
+            Pred = CmpIPredicate::sge;
+          break;
+        case CmpIPredicate::uge:
+          // (float)int >= -4.4 --> true
+          // (float)int >= 4.4 --> int > 4
+          if (RHS.isNegative())
+            return foldToBool(true);
+          Pred = CmpIPredicate::ugt;
+          break;
+        case CmpIPredicate::sge:
+          // (float)int >= -4.4 --> int >= -4
+          // (float)int >= 4.4 --> int > 4
+          if (!RHS.isNegative())
+            Pred = CmpIPredicate::sgt;
+          break;
+        }
+      }
+    }
+
+    // Lower this FP comparison into an appropriate integer version of the
+    // comparison.
+    rewriter.replaceOpWithNewOp<CmpIOp>(
+        op, Pred, intVal,
+        rewriter.create<ConstantOp>(
+            op.getLoc(), intVal.getType(),
+            rewriter.getIntegerAttr(intVal.getType(), RHSInt)));
+    return success();
+  }
+};
+
+void arith::CmpFOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<CmpFIntToFPConst>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Atomic Enum
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -142,6 +142,12 @@
   return FloatType();
 }
 
+/// Return the precision of this type's significand in bits, as reported by
+/// its APFloat semantics (e.g. 8 for bf16, 24 for f32, 53 for f64).
+unsigned FloatType::getFPMantissaWidth() {
+  return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
 //===----------------------------------------------------------------------===//
 // FunctionType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -691,3 +691,74 @@
   %res = arith.maxf %const, %min : f32
   return %res : f32
 }
+
+// -----
+
+// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll
+// When canonicalizing an arith.cmpf whose LHS comes from an arith.uitofp op, we
+// must lower it to unsigned arith.cmpi predicates, not signed ones.
+
+// CHECK-LABEL: @test1(
+func @test1(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf ole, %1, %cst : f64
+// CHECK: arith.cmpi ule, %arg0, %c0_i32 : i32
+  return %2 : i1
+}
+
+// CHECK-LABEL: @test2(
+func @test2(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf olt, %1, %cst : f64
+  return %2 : i1
+// CHECK: arith.cmpi ult, %arg0, %c0_i32 : i32
+}
+
+// CHECK-LABEL: @test3(
+func @test3(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf oge, %1, %cst : f64
+  return %2 : i1
+// CHECK: arith.cmpi uge, %arg0, %c0_i32 : i32
+}
+
+// CHECK-LABEL: @test4(
+func @test4(%arg0: i32) -> i1 {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf ogt, %1, %cst : f64
+// CHECK: arith.cmpi ugt, %arg0, %c0_i32 : i32
+  return %2 : i1
+}
+
+// CHECK-LABEL: @test5(
+func @test5(%arg0: i32) -> i1 {
+  %cst = arith.constant -4.400000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf ogt, %1, %cst : f64
+  return %2 : i1
+// CHECK: return %true : i1
+}
+
+// CHECK-LABEL: @test6(
+func @test6(%arg0: i32) -> i1 {
+  %cst = arith.constant -4.400000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf olt, %1, %cst : f64
+  return %2 : i1
+// CHECK: return %false : i1
+}
+
+// Check that optimizing unsigned >= comparisons correctly distinguishes
+// positive and negative constants.
+// CHECK-LABEL: @test7(
+func @test7(%arg0: i32) -> i1 {
+  %cst = arith.constant 3.200000e+00 : f64
+  %1 = arith.uitofp %arg0: i32 to f64
+  %2 = arith.cmpf oge, %1, %cst : f64
+  return %2 : i1
+// CHECK: arith.cmpi ugt, %arg0, %c3_i32 : i32
+}