Index: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp =================================================================== --- mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -39,6 +39,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Convert floating-point comparison (std.cmpf) to SPIR-V FOrd*/FUnord* ops. +class CmpFOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Convert compare operation to SPIR-V dialect. class CmpIOpConversion final : public SPIRVOpLowering { public: @@ -195,6 +205,46 @@ return matchSuccess(); } +//===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + +PatternMatchResult +CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const { + CmpFOpOperandAdaptor cmpFOpOperands(operands); + + switch (cmpFOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp<spirvOp>( \ + cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \ + cmpFOpOperands.rhs()); \ + return matchSuccess(); + + // Ordered predicates lower to spv.FOrd* ops. + DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + // Unordered predicates lower to spv.FUnord* ops.
+ DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + +#undef DISPATCH + + default: + break; + } + return matchFailure(); +} + //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// @@ -218,6 +268,10 @@ DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); + DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); + DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); + DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); + DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH @@ -302,7 +356,7 @@ OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect.
populateWithGenerated(context, &patterns); - patterns.insert, IntegerOpConversion, IntegerOpConversion, Index: mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir =================================================================== --- mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -94,6 +94,39 @@ return } +//===----------------------------------------------------------------------===// +// std.cmpf +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cmpf +func @cmpf(%arg0 : f32, %arg1 : f32) { + // CHECK: spv.FOrdEqual + %1 = cmpf "oeq", %arg0, %arg1 : f32 + // CHECK: spv.FOrdGreaterThan + %2 = cmpf "ogt", %arg0, %arg1 : f32 + // CHECK: spv.FOrdGreaterThanEqual + %3 = cmpf "oge", %arg0, %arg1 : f32 + // CHECK: spv.FOrdLessThan + %4 = cmpf "olt", %arg0, %arg1 : f32 + // CHECK: spv.FOrdLessThanEqual + %5 = cmpf "ole", %arg0, %arg1 : f32 + // CHECK: spv.FOrdNotEqual + %6 = cmpf "one", %arg0, %arg1 : f32 + // CHECK: spv.FUnordEqual + %7 = cmpf "ueq", %arg0, %arg1 : f32 + // CHECK: spv.FUnordGreaterThan + %8 = cmpf "ugt", %arg0, %arg1 : f32 + // CHECK: spv.FUnordGreaterThanEqual + %9 = cmpf "uge", %arg0, %arg1 : f32 + // CHECK: spv.FUnordLessThan + %10 = cmpf "ult", %arg0, %arg1 : f32 + // CHECK: spv.FUnordLessThanEqual + %11 = cmpf "ule", %arg0, %arg1 : f32 + // CHECK: spv.FUnordNotEqual + %12 = cmpf "une", %arg0, %arg1 : f32 + return +} + //===----------------------------------------------------------------------===// // std.cmpi //===----------------------------------------------------------------------===// @@ -112,6 +145,14 @@ %4 = cmpi "sgt", %arg0, %arg1 : i32 // CHECK: spv.SGreaterThanEqual %5 = cmpi "sge", %arg0, %arg1 : i32 + // CHECK: spv.ULessThan + %6 = cmpi "ult", %arg0, %arg1 : i32 + // CHECK: spv.ULessThanEqual + %7 = cmpi "ule", %arg0, %arg1 : i32 + // CHECK: spv.UGreaterThan + %8 = cmpi "ugt", %arg0, %arg1 : i32 + // CHECK: 
spv.UGreaterThanEqual + %9 = cmpi "uge", %arg0, %arg1 : i32 return }