diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -39,6 +39,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Convert floating-point comparison operations to SPIR-V dialect. +class CmpFOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Convert compare operation to SPIR-V dialect. class CmpIOpConversion final : public SPIRVOpLowering { public: @@ -196,6 +206,46 @@ } //===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + CmpFOpOperandAdaptor cmpFOpOperands(operands); + + switch (cmpFOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp( \ + cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \ + cmpFOpOperands.rhs()); \ + return matchSuccess(); + + // Ordered. + DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + // Unordered. 
+ DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + +#undef DISPATCH + + default: + break; + } + return matchFailure(); +} + +//===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// @@ -218,11 +268,12 @@ DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); + DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); + DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); + DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH - - default: - break; } return matchFailure(); } @@ -302,7 +353,7 @@ OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect. 
populateWithGenerated(context, &patterns); - patterns.insert, IntegerOpConversion, IntegerOpConversion, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -143,6 +143,39 @@ } //===----------------------------------------------------------------------===// +// std.cmpf +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cmpf +func @cmpf(%arg0 : f32, %arg1 : f32) { + // CHECK: spv.FOrdEqual + %1 = cmpf "oeq", %arg0, %arg1 : f32 + // CHECK: spv.FOrdGreaterThan + %2 = cmpf "ogt", %arg0, %arg1 : f32 + // CHECK: spv.FOrdGreaterThanEqual + %3 = cmpf "oge", %arg0, %arg1 : f32 + // CHECK: spv.FOrdLessThan + %4 = cmpf "olt", %arg0, %arg1 : f32 + // CHECK: spv.FOrdLessThanEqual + %5 = cmpf "ole", %arg0, %arg1 : f32 + // CHECK: spv.FOrdNotEqual + %6 = cmpf "one", %arg0, %arg1 : f32 + // CHECK: spv.FUnordEqual + %7 = cmpf "ueq", %arg0, %arg1 : f32 + // CHECK: spv.FUnordGreaterThan + %8 = cmpf "ugt", %arg0, %arg1 : f32 + // CHECK: spv.FUnordGreaterThanEqual + %9 = cmpf "uge", %arg0, %arg1 : f32 + // CHECK: spv.FUnordLessThan + %10 = cmpf "ult", %arg0, %arg1 : f32 + // CHECK: spv.FUnordLessThanEqual + %11 = cmpf "ule", %arg0, %arg1 : f32 + // CHECK: spv.FUnordNotEqual + %12 = cmpf "une", %arg0, %arg1 : f32 + return +} + +//===----------------------------------------------------------------------===// // std.cmpi //===----------------------------------------------------------------------===// @@ -160,6 +193,14 @@ %4 = cmpi "sgt", %arg0, %arg1 : i32 // CHECK: spv.SGreaterThanEqual %5 = cmpi "sge", %arg0, %arg1 : i32 + // CHECK: spv.ULessThan + %6 = cmpi "ult", %arg0, %arg1 : i32 + // CHECK: spv.ULessThanEqual + %7 = cmpi "ule", %arg0, %arg1 : i32 + // CHECK: spv.UGreaterThan + %8 = cmpi "ugt", %arg0, %arg1 : i32 + // CHECK: 
spv.UGreaterThanEqual + %9 = cmpi "uge", %arg0, %arg1 : i32 return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir @@ -51,6 +51,33 @@ %0 = spv.ULessThanEqual %arg0, %arg1 : vector<4xi32> spv.Return } + func @cmpf(%arg0 : f32, %arg1 : f32) { + // CHECK: spv.FOrdEqual + %1 = spv.FOrdEqual %arg0, %arg1 : f32 + // CHECK: spv.FOrdGreaterThan + %2 = spv.FOrdGreaterThan %arg0, %arg1 : f32 + // CHECK: spv.FOrdGreaterThanEqual + %3 = spv.FOrdGreaterThanEqual %arg0, %arg1 : f32 + // CHECK: spv.FOrdLessThan + %4 = spv.FOrdLessThan %arg0, %arg1 : f32 + // CHECK: spv.FOrdLessThanEqual + %5 = spv.FOrdLessThanEqual %arg0, %arg1 : f32 + // CHECK: spv.FOrdNotEqual + %6 = spv.FOrdNotEqual %arg0, %arg1 : f32 + // CHECK: spv.FUnordEqual + %7 = spv.FUnordEqual %arg0, %arg1 : f32 + // CHECK: spv.FUnordGreaterThan + %8 = spv.FUnordGreaterThan %arg0, %arg1 : f32 + // CHECK: spv.FUnordGreaterThanEqual + %9 = spv.FUnordGreaterThanEqual %arg0, %arg1 : f32 + // CHECK: spv.FUnordLessThan + %10 = spv.FUnordLessThan %arg0, %arg1 : f32 + // CHECK: spv.FUnordLessThanEqual + %11 = spv.FUnordLessThanEqual %arg0, %arg1 : f32 + // CHECK: spv.FUnordNotEqual + %12 = spv.FUnordNotEqual %arg0, %arg1 : f32 + spv.Return + } } // -----