[mlir][spirv] Add spv.IsNan / spv.IsInf; lower cmpf ord/uno without Kernel

What the hunks below do:
  * SPIRVBase.td: register opcodes OpIsNan (156) and OpIsInf (157) and list
    them in the opcode enum.
  * SPIRVLogicalOps.td: give SPV_LogicalUnaryOp a builder that takes only the
    operand, and define spv.IsInf / spv.IsNan on SPV_Float operands.
  * SPIRVOps.cpp: buildLogicalUnaryOp picks i1, or a vector of i1 with the
    operand's shape, as the result type.
  * StandardToSPIRV.cpp: cmpf ord/uno are removed from the generic DISPATCH
    table. CmpFOpNanKernelPattern (benefit 2) emits spv.Ordered/spv.Unordered;
    CmpFOpNanNonePattern emits IsNan(lhs) || IsNan(rhs), negated for ord.
  * Tests for the op syntax, the conversion and the (de)serialization
    round trip.

Reviewer notes on this copy of the patch:
  * Fixed: "// CHCK:" -> "// CHECK:" in test/Target/SPIRV/logical-ops.mlir;
    FileCheck silently ignored the misspelled directives.
  * Fixed: the vector examples in the op descriptions used vector<4xi32>,
    although the operand must be floating-point; now vector<4xf32>.
  * This copy had been flattened and had lost its "<...>" spans. They are
    restored from evidence inside the patch (removed DISPATCH lines, FileCheck
    expectations, matchAndRewrite signatures).
  * ASSUMED, please confirm: the "v1.0" version and empty extension list in
    both #spv.vce<...> attributes; only the presence/absence of Kernel is
    stated by the accompanying comments.
  * NOT RECOVERED: the template arguments on the two TypeCastingOpPattern
    context lines in the last StandardToSPIRV.cpp hunk. Regenerate that
    context against the target tree before applying.
  * Indentation is reconstructed in LLVM style; apply with
    `git apply --ignore-whitespace` if context whitespace does not match.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3216,6 +3216,8 @@
 def SPV_OC_OpFMod                   : I32EnumAttrCase<"OpFMod", 141>;
 def SPV_OC_OpMatrixTimesScalar      : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
 def SPV_OC_OpMatrixTimesMatrix      : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
+def SPV_OC_OpIsNan                  : I32EnumAttrCase<"OpIsNan", 156>;
+def SPV_OC_OpIsInf                  : I32EnumAttrCase<"OpIsInf", 157>;
 def SPV_OC_OpOrdered                : I32EnumAttrCase<"OpOrdered", 162>;
 def SPV_OC_OpUnordered              : I32EnumAttrCase<"OpUnordered", 163>;
 def SPV_OC_OpLogicalEqual           : I32EnumAttrCase<"OpLogicalEqual", 164>;
@@ -3332,15 +3334,15 @@
       SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
       SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
       SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
+      SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -41,6 +41,11 @@
                                        SameOperandsAndResultShape])> {
   let parser = [{ return ::parseLogicalUnaryOp(parser, result); }];
   let printer = [{ return ::printLogicalOp(getOperation(), p); }];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$value),
+    [{::buildLogicalUnaryOp($_builder, $_state, value);}]>
+  ];
 }
 
 // -----
@@ -507,6 +512,70 @@
 
 // -----
 
+def SPV_IsInfOp : SPV_LogicalUnaryOp<"IsInf", SPV_Float, []> {
+  let summary = "Result is true if x is an IEEE Inf, otherwise result is false";
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+    x must be a scalar or vector of floating-point type. It must have the
+    same number of components as Result Type.
+
+    Results are computed per component.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    isinf-op ::= ssa-id `=` `spv.IsInf` ssa-use
+                 `:` float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.IsInf %0: f32
+    %3 = spv.IsInf %1: vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_IsNanOp : SPV_LogicalUnaryOp<"IsNan", SPV_Float, []> {
+  let summary = [{
+    Result is true if x is an IEEE NaN, otherwise result is false.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+    x must be a scalar or vector of floating-point type. It must have the
+    same number of components as Result Type.
+
+    Results are computed per component.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    isnan-op ::= ssa-id `=` `spv.IsNan` ssa-use
+                 `:` float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.IsNan %0: f32
+    %3 = spv.IsNan %1: vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd",
                                            SPV_Bool,
                                            [Commutative,
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -386,6 +386,28 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts floating point NaN check to SPIR-V ops. This pattern requires
+/// Kernel capability.
+class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> {
+public:
+  using OpConversionPattern<CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts floating point NaN check to SPIR-V ops. This pattern does not
+/// require additional capability.
+class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> {
+public:
+  using OpConversionPattern<CmpFOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
 class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
 public:
@@ -730,7 +752,6 @@
     DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
     DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
     DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
-    DISPATCH(CmpFPredicate::ORD, spirv::OrderedOp);
     // Unordered.
     DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
     DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
@@ -738,7 +759,6 @@
     DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
     DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
     DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
-    DISPATCH(CmpFPredicate::UNO, spirv::UnorderedOp);
 
 #undef DISPATCH
 
@@ -748,6 +768,47 @@
   return failure();
 }
 
+LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
+    CmpFOp cmpFOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  CmpFOpAdaptor cmpFOpOperands(operands);
+
+  if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
+    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(),
+                                                  cmpFOpOperands.rhs());
+    return success();
+  }
+
+  if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
+    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(
+        cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
+    return success();
+  }
+
+  return failure();
+}
+
+LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
+    CmpFOp cmpFOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
+      cmpFOp.getPredicate() != CmpFPredicate::UNO)
+    return failure();
+
+  CmpFOpAdaptor cmpFOpOperands(operands);
+  Location loc = cmpFOp.getLoc();
+
+  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs());
+  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs());
+
+  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+  if (cmpFOp.getPredicate() == CmpFPredicate::ORD)
+    replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+
+  rewriter.replaceOp(cmpFOp, replace);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
@@ -1102,7 +1163,7 @@
       SignedRemIOpPattern, XOrOpPattern,
 
       // Comparison patterns
-      BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,
+      BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
 
       // Constant patterns
       ConstantCompositeOpPattern, ConstantScalarOpPattern,
@@ -1124,5 +1185,10 @@
       TypeCastingOpPattern,
       TypeCastingOpPattern>(typeConverter,
                             context);
+
+  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
+  // capability is available.
+  patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
+                                          /*benefit=*/2);
 }
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -900,6 +900,16 @@
   state.addOperands({lhs, rhs});
 }
 
+static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
+                                Value value) {
+  Type boolType = builder.getI1Type();
+  if (auto vecType = value.getType().dyn_cast<VectorType>())
+    boolType = VectorType::get(vecType.getShape(), boolType);
+  state.addTypes(boolType);
+
+  state.addOperands(value);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.AccessChainOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -301,6 +301,7 @@
 
 // -----
 
+// With Kernel capability, we can convert NaN check to spv.Ordered/spv.Unordered.
 module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}>
 } {
@@ -318,6 +319,31 @@
 
 // -----
 
+// Without Kernel capability, we need to convert NaN check to spv.IsNan.
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: @cmpf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func @cmpf(%arg0 : f32, %arg1 : f32) {
+  // CHECK:      %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+  // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+  // CHECK-NEXT: %[[OR:.+]] = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1
+  // CHECK-NEXT: %{{.+}} = spv.LogicalNot %[[OR]] : i1
+  %0 = cmpf ord, %arg0, %arg1 : f32
+
+  // CHECK-NEXT: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32
+  // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32
+  // CHECK-NEXT: %{{.+}} = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1
+  %1 = cmpf uno, %arg0, %arg1 : f32
+  return
+}
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std.cmpi
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -32,6 +32,40 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.IsInf
+//===----------------------------------------------------------------------===//
+
+func @isinf_scalar(%arg0: f32) -> i1 {
+  // CHECK: spv.IsInf {{.*}} : f32
+  %0 = spv.IsInf %arg0 : f32
+  return %0 : i1
+}
+
+func @isinf_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+  // CHECK: spv.IsInf {{.*}} : vector<2xf32>
+  %0 = spv.IsInf %arg0 : vector<2xf32>
+  return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.IsNan
+//===----------------------------------------------------------------------===//
+
+func @isnan_scalar(%arg0: f32) -> i1 {
+  // CHECK: spv.IsNan {{.*}} : f32
+  %0 = spv.IsNan %arg0 : f32
+  return %0 : i1
+}
+
+func @isnan_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+  // CHECK: spv.IsNan {{.*}} : vector<2xf32>
+  %0 = spv.IsNan %arg0 : vector<2xf32>
+  return %0 : vector<2xi1>
+}
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalAnd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -80,6 +80,10 @@
     %13 = spv.Ordered %arg0, %arg1 : f32
     // CHECK: spv.Unordered
     %14 = spv.Unordered %arg0, %arg1 : f32
+    // CHECK: spv.IsNan
+    %15 = spv.IsNan %arg0 : f32
+    // CHECK: spv.IsInf
+    %16 = spv.IsInf %arg1 : f32
     spv.Return
   }
 }