diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3216,6 +3216,8 @@ def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; +def SPV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; +def SPV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>; def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; @@ -3330,14 +3332,15 @@ SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, - SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, - SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpOrdered, SPV_OC_OpUnordered, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, 
SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -688,6 +688,41 @@ // ----- +def SPV_OrderedOp : SPV_LogicalBinaryOp<"Ordered", SPV_Float, [Commutative]> { + let summary = [{ + Result is true if both x == x and y == y are true, where IEEE comparison + is used, otherwise result is false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. It must have the + same number of components as Result Type. + + y must have the same type as x. + + Results are computed per component. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + ordered-op ::= ssa-id `=` `spv.Ordered` ssa-use, ssa-use + ``` + + #### Example: + + ```mlir + %4 = spv.Ordered %0, %1 : f32 + %5 = spv.Ordered %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, [UsableInSpecConstantOp, SignedOp]> { @@ -1007,6 +1042,41 @@ // ----- +def SPV_UnorderedOp : SPV_LogicalBinaryOp<"Unordered", SPV_Float, [Commutative]> { + let summary = [{ + Result is true if either x or y is an IEEE NaN, otherwise result is + false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. 
It must have the + same number of components as Result Type. + + y must have the same type as x. + + Results are computed per component. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + unordered-op ::= ssa-id `=` `spv.Unordered` ssa-use, ssa-use + ``` + + #### Example: + + ```mlir + %4 = spv.Unordered %0, %1 : f32 + %5 = spv.Unordered %2, %3 : vector<4xf32> + ``` + }]; +} + +// ----- + def SPV_ULessThanEqualOp : SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, [UnsignedOp, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -714,6 +714,7 @@ DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + DISPATCH(CmpFPredicate::ORD, spirv::OrderedOp); // Unordered. 
DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); @@ -721,6 +722,7 @@ DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + DISPATCH(CmpFPredicate::UNO, spirv::UnorderedOp); #undef DISPATCH diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -286,6 +286,10 @@ %11 = cmpf ule, %arg0, %arg1 : f32 // CHECK: spv.FUnordNotEqual %12 = cmpf une, %arg0, %arg1 : f32 + // CHECK: spv.Ordered + %13 = cmpf ord, %arg0, %arg1 : f32 + // CHECK: spv.Unordered + %14 = cmpf uno, %arg0, %arg1 : f32 return } diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir --- a/mlir/test/Target/SPIRV/logical-ops.mlir +++ b/mlir/test/Target/SPIRV/logical-ops.mlir @@ -76,6 +76,10 @@ %11 = spv.FUnordLessThanEqual %arg0, %arg1 : f32 // CHECK: spv.FUnordNotEqual %12 = spv.FUnordNotEqual %arg0, %arg1 : f32 + // CHECK: spv.Ordered + %13 = spv.Ordered %arg0, %arg1 : f32 + // CHECK: spv.Unordered + %14 = spv.Unordered %arg0, %arg1 : f32 spv.Return } } diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh --- a/mlir/utils/spirv/define_inst.sh +++ b/mlir/utils/spirv/define_inst.sh @@ -42,7 +42,7 @@ python3 ${current_dir}/gen_spirv_dialect.py \ --op-td-path \ - ${current_dir}/../../include/mlir/Dialect/SPIRV/${file_name} \ + ${current_dir}/../../include/mlir/Dialect/SPIRV/IR/${file_name} \ --inst-category $baseclass --new-inst "$@" ${current_dir}/define_opcodes.sh "$@"