diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -49,9 +49,29 @@ ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } -private: +protected: NamedAttrList convertedAttr; }; + +// Wrapper around AttrConvertFastMathToLLVM that also sets the "kinds" attribute +// to the llvm.is.fpclass test mask given by the `Kinds` template parameter. +// Used for operations (e.g. arith.isnan, arith.isinf) lowered to is.fpclass. +template +class AttrConvertAddFpclassKinds + : public AttrConvertFastMathToLLVM { +public: + AttrConvertAddFpclassKinds(SourceOp op) + : AttrConvertFastMathToLLVM(op) { + convertedAttr.set( + "kinds", + IntegerAttr::get(IntegerType::get(op.getContext(), 32), Kinds)); + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +protected: + using AttrConvertFastMathToLLVM::convertedAttr; +}; } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -83,6 +83,20 @@ attr-dict `:` type($result) }]; } +// Base class for floating point predicate operations (e.g. isnan, isinf). +class Arith_FloatPredicateOp traits = []> : + Arith_Op, Pure, + TypesMatchWith<"result type has i1 element type and same shape as operands", + "operand", "result", "::getI1SameShape($_self)">], traits)>, + Arguments<(ins FloatLike:$operand, + DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>, + Results<(outs BoolLike:$result)> { + let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? + attr-dict `:` type($operand) }]; +} + // Base class for arithmetic cast operations. Requires a single operand and // result.
If either is a shaped type, then the other must be of the same shape. class Arith_CastOp { + let summary = "Returns true for IEEE NaN inputs"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// IsInfOp +//===----------------------------------------------------------------------===// +def Arith_IsInfOp : Arith_FloatPredicateOp<"isinf"> { + let summary = "Returns true for infinite float inputs"; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -120,17 +120,21 @@ }]; } -def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure]> { - let arguments = (ins LLVM_ScalarOrVectorOf:$in, I32Attr:$bit); +def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure], + /*requiresFastmath=*/1> { + let arguments = (ins LLVM_ScalarOrVectorOf:$in, + I32Attr:$kinds, + DefaultValuedAttr:$fastmathFlags); string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, - $_resultType, $in, $_int_attr($bit)); - $res = op; + $_resultType, $in, $_int_attr($kinds), nullptr); + moduleImport.setFastmathFlagsAttr(inst, op); + $res = op; }]; string llvmBuilder = [{ auto *inst = createIntrinsicCall( builder, llvm::Intrinsic::}] # llvmEnumName # [{, - {$in, builder.getInt32(op.getBit())}, + {$in, builder.getInt32(op.getKinds())}, }] # declTypes # [{); $res = inst; }]; diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -53,6 +53,16 @@
VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; +template +using InfKinds = // is.fpclass test-mask bits: 2 = -inf, 9 = +inf + arith::AttrConvertAddFpclassKinds<(1 << 2) | (1 << 9), SourceOp, TargetOp>; +using IsInfLowering = + VectorConvertToLLVMPattern; +template +using NanKinds = // is.fpclass test-mask bits: 0 = sNaN, 1 = qNaN + arith::AttrConvertAddFpclassKinds<(1 << 0) | (1 << 1), SourceOp, TargetOp>; +using IsNanLowering = + VectorConvertToLLVMPattern; using MaxFOpLowering = VectorConvertToLLVMPattern; @@ -465,6 +475,8 @@ FPToUIOpLowering, IndexCastOpSILowering, IndexCastOpUILowering, + IsInfLowering, + IsNanLowering, MaxFOpLowering, MaxSIOpLowering, MaxUIOpLowering, diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -1094,6 +1094,8 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, TypeCastingOpPattern, ExtUII1Pattern, TypeCastingOpPattern, ExtSII1Pattern, TypeCastingOpPattern, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1112,6 +1112,42 @@ }); } +static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { + auto boolAttr = BoolAttr::get(ctx, value); + ShapedType shapedType = llvm::dyn_cast_or_null(type); + if (!shapedType) + return boolAttr; + return DenseElementsAttr::get(shapedType, boolAttr); +} + +//===----------------------------------------------------------------------===// +// IsNanOp +//===----------------------------------------------------------------------===// +OpFoldResult IsNanOp::fold(FoldAdaptor adaptor) { + if (bitEnumContainsAll(getFastmath(), FastMathFlags::nnan)) + return getBoolAttribute(getType(), getContext(), false); + return constFoldCastOp(
adaptor.getOperands(), getType(), + [](const APFloat &x, bool &success) -> APInt { + success = true; + return APInt(1, x.isNaN()); + }); +} + +//===----------------------------------------------------------------------===// +// IsInfOp +//===----------------------------------------------------------------------===// +OpFoldResult IsInfOp::fold(FoldAdaptor adaptor) { + if (bitEnumContainsAll(getFastmath(), FastMathFlags::ninf)) + return getBoolAttribute(getType(), getContext(), false); + return constFoldCastOp( + adaptor.getOperands(), getType(), + [](const APFloat &x, bool &success) -> APInt { + success = true; + return APInt(1, x.isInfinity()); + }); +} + //===----------------------------------------------------------------------===// // Utility functions for verifying cast ops //===----------------------------------------------------------------------===// @@ -1650,14 +1686,6 @@ llvm_unreachable("unknown cmpi predicate kind"); } -static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { - auto boolAttr = BoolAttr::get(ctx, value); - ShapedType shapedType = llvm::dyn_cast_or_null(type); - if (!shapedType) - return boolAttr; - return DenseElementsAttr::get(shapedType, boolAttr); -} - static std::optional getIntegerWidth(Type t) { if (auto intType = llvm::dyn_cast(t)) { return intType.getWidth(); diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -179,8 +179,7 @@ Value select = rewriter.create(loc, cmp, lhs, rhs); // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
- Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, - rhs, rhs); + Value isNaN = rewriter.create(loc, rhs, op.getFastmath()); rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); } diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -71,6 +71,28 @@ return %0, %4 : f32, i32 } +// CHECK-LABEL: @float_pred_ops +func.func @float_pred_ops(%arg0: f32) { + // CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 516 : i32}> : (f32) -> i1 + arith.isinf %arg0 : f32 + // CHECK-NEXT: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 516 : i32}> : (f32) -> i1 + arith.isinf %arg0 fastmath : f32 + // CHECK-NEXT: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 3 : i32}> : (f32) -> i1 + arith.isnan %arg0 : f32 + // CHECK-NEXT: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 3 : i32}> : (f32) -> i1 + arith.isnan %arg0 fastmath : f32 + return +} + +// CHECK-LABEL: @vector_float_pred_ops +func.func @vector_float_pred_ops(%arg0: vector<4xf32>) { + // CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 516 : i32}> : (vector<4xf32>) -> vector<4xi1> + arith.isinf %arg0 : vector<4xf32> + // CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 3 : i32}> : (vector<4xf32>) -> vector<4xi1> + arith.isnan %arg0 : vector<4xf32> + return +} + // Checking conversion of index types to integers using i1, assuming no target // system would have a 1-bit address space. Otherwise, we would have had to // make this test dependent on the pointer size on the target system.
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -169,6 +169,16 @@ return } +// Check float predicate operation conversions. +// CHECK-LABEL: @float32_predicate_scalar +func.func @float32_predicate_scalar(%arg0 : f32) { + // CHECK: spirv.IsNan %{{.*}}: f32 + %0 = arith.isnan %arg0 : f32 + // CHECK: spirv.IsInf %{{.*}}: f32 + %1 = arith.isinf %arg0 : f32 + return +} + // Check int vector types. // CHECK-LABEL: @int_vector234 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2567,3 +2567,78 @@ %2 = arith.ori %arg0, %1 : index return %2 : index } + +// ----- + +// CHECK-LABEL: @foldIsNanFastmath +// CHECK-SAME: (%[[ARG:.+]]: f32) +// CHECK: %[[FALSE:.+]] = arith.constant false +// CHECK: return %[[FALSE]] +func.func @foldIsNanFastmath(%arg0: f32) -> i1 { + %0 = arith.isnan %arg0 fastmath : f32 + func.return %0 : i1 +} + +// CHECK-LABEL: @foldIsNan +// CHECK: %[[TRUE:.+]] = arith.constant true +// CHECK: return %[[TRUE]] +func.func @foldIsNan() -> i1 { + %cNan = arith.constant 0x7FFFFFFF : f32 + %0 = arith.isnan %cNan : f32 + func.return %0 : i1 +} + +// CHECK-LABEL: @foldNanIsNotNanWithFastmath +// CHECK: %[[FALSE:.+]] = arith.constant false +// CHECK: return %[[FALSE]] +func.func @foldNanIsNotNanWithFastmath() -> i1 { + %cNan = arith.constant 0x7FFFFFFF : f32 + %0 = arith.isnan %cNan fastmath : f32 + func.return %0 : i1 +} + + +// CHECK-LABEL: @foldIsNotNan +// CHECK: %[[FALSE:.+]] = arith.constant false +// CHECK: return %[[FALSE]] +func.func @foldIsNotNan() -> i1 { + %cOne = arith.constant 1.0 : f32 + %0 = arith.isnan %cOne : f32 + func.return %0 : i1 +} +
+// CHECK-LABEL: @foldIsInfFastmath +// CHECK: %[[FALSE:.+]] = arith.constant false +// CHECK: return %[[FALSE]] +func.func @foldIsInfFastmath(%arg0: f32) -> i1 { + %0 = arith.isinf %arg0 fastmath : f32 + func.return %0 : i1 +} + +// CHECK-LABEL: @foldIsInf +// CHECK: %[[TRUE:.+]] = arith.constant true +// CHECK: return %[[TRUE]] +func.func @foldIsInf() -> i1 { + %cInf = arith.constant 0x7F800000 : f32 + %0 = arith.isinf %cInf : f32 + func.return %0 : i1 +} + +// CHECK-LABEL: @foldInfIsNotInfWithFastmath +// CHECK: %[[FALSE:.+]] = arith.constant false +// CHECK: return %[[FALSE]] +func.func @foldInfIsNotInfWithFastmath() -> i1 { + %cInf = arith.constant 0x7F800000 : f32 + %0 = arith.isinf %cInf fastmath : f32 + func.return %0 : i1 +} + + +// CHECK-LABEL: @foldIsNotInf +// CHECK: %[[FALSE:.+]] = arith.constant false +// CHECK: return %[[FALSE]] +func.func @foldIsNotInf() -> i1 { + %cOne = arith.constant 1.0 : f32 + %0 = arith.isinf %cOne : f32 + func.return %0 : i1 +} diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -184,7 +184,7 @@ // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.isnan %[[RHS]] : f32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 @@ -198,7 +198,7 @@ // CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16> // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT:
%[[IS_NAN:.*]] = arith.isnan %[[RHS]] : vector<4xf16> // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] // CHECK-NEXT: return %[[RESULT]] : vector<4xf16> @@ -213,7 +213,7 @@ // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) // CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.isnan %[[RHS]] : f32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -577,6 +577,30 @@ return %0 : vector<[8]xf64> } +// CHECK-LABEL: test_isnan +func.func @test_isnan(%arg0 : f32) -> i1 { + %0 = arith.isnan %arg0 : f32 + func.return %0 : i1 +} + +// CHECK-LABEL: test_isnan_vector +func.func @test_isnan_vector(%arg0 : vector<2x2xf32>) -> vector<2x2xi1> { + %0 = arith.isnan %arg0 : vector<2x2xf32> + func.return %0 : vector<2x2xi1> +} + +// CHECK-LABEL: test_isinf +func.func @test_isinf(%arg0 : f32) -> i1 { + %0 = arith.isinf %arg0 : f32 + func.return %0 : i1 +} + +// CHECK-LABEL: test_isinf_vector +func.func @test_isinf_vector(%arg0 : vector<2x2xf32>) -> vector<2x2xi1> { + %0 = arith.isinf %arg0 : vector<2x2xf32> + func.return %0 : vector<2x2xi1> +} + // CHECK-LABEL: test_extui func.func @test_extui(%arg0 : i32) -> i64 { %0 = arith.extui %arg0 : i32 to i64 diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -17,9 +17,9 @@ ; CHECK-LABEL: llvm.func @fpclass_test define void @fpclass_test(float %0, <8 x float> %1) { - ; CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 0 :
i32}> : (f32) -> i1 + ; CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 0 : i32}> : (f32) -> i1 %3 = call i1 @llvm.is.fpclass.f32(float %0, i32 0) - ; CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 1 : i32}> : (vector<8xf32>) -> vector<8xi1> + ; CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath, kinds = 1 : i32}> : (vector<8xf32>) -> vector<8xi1> %4 = call <8 x i1> @llvm.is.fpclass.v8f32(<8 x float> %1, i32 1) ret void } diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -18,7 +18,7 @@ // CHECK-LABEL: @fpclass_test llvm.func @fpclass_test(%arg0: f32) -> i1 { // CHECK: call i1 @llvm.is.fpclass - %0 = "llvm.intr.is.fpclass"(%arg0) <{bit = 0 : i32 }>: (f32) -> i1 + %0 = "llvm.intr.is.fpclass"(%arg0) <{fastmathFlags = #llvm.fastmath, kinds = 3 : i32 }>: (f32) -> i1 llvm.return %0 : i1 }