diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -156,19 +156,17 @@ Value rhs = op.getRhs(); Location loc = op.getLoc(); + // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). + static_assert(pred == arith::CmpFPredicate::UGT || + pred == arith::CmpFPredicate::ULT, + "pred must be either UGT or ULT"); Value cmp = rewriter.create(loc, pred, lhs, rhs); Value select = rewriter.create(loc, cmp, lhs, rhs); - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, - lhs, rhs); - - Value nan = rewriter.create( - loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType); - if (VectorType vectorType = lhs.getType().dyn_cast()) - nan = rewriter.create(loc, vectorType, nan); - - rewriter.replaceOpWithNewOp(op, isNaN, nan, select); + rhs, rhs); + rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); } }; @@ -226,8 +224,8 @@ CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter, - MaxMinFOpConverter, - MaxMinFOpConverter, + MaxMinFOpConverter, + MaxMinFOpConverter, MaxMinIOpConverter, MaxMinIOpConverter, MaxMinIOpConverter, diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir --- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir +++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir @@ -154,11 +154,10 @@ return %result : f32 } // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = 
arith.constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 // ----- @@ -169,12 +168,10 @@ return %result : vector<4xf16> } // CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16> // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16 -// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16> -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]] +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] // CHECK-NEXT: return %[[RESULT]] : vector<4xf16> // ----- @@ -185,11 +182,10 @@ return %result : f32 } // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32