diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -184,6 +184,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts integer compare operation on i1 type operands to SPIR-V ops. +class BoolCmpIOpPattern final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public SPIRVOpLowering { public: @@ -454,10 +464,42 @@ //===----------------------------------------------------------------------===// LogicalResult +BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + CmpIOpOperandAdaptor cmpIOpOperands(operands); + + Type operandType = cmpIOp.lhs().getType(); // NOTE(review): scalar i1 only; confirm vector-of-i1 handling. + if (!operandType.isa() || + operandType.cast().getWidth() != 1) + return failure(); // Non-i1 operands are left to CmpIOpPattern. + + switch (cmpIOp.getPredicate()) { +#define DISPATCH(cmpPredicate, spirvOp) \ + case cmpPredicate: \ + rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ + cmpIOpOperands.lhs(), \ + cmpIOpOperands.rhs()); \ + return success(); + + DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); + DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); + +#undef DISPATCH + default:; // Other predicates are not mapped by this pattern. + } + return failure(); +} + +LogicalResult CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); + Type operandType = cmpIOp.lhs().getType(); + if (operandType.isa() && + operandType.cast().getWidth() == 1) + return failure(); // i1 operands are handled by BoolCmpIOpPattern. + switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \ @@ -599,9 +641,10 @@ UnaryAndBinaryOpPattern, BitwiseOpPattern, BitwiseOpPattern, - ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern, - CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern, - StoreOpPattern, TypeCastingOpPattern, + BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, + CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern, + SelectOpPattern, StoreOpPattern, + TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, XOrOpPattern>( context, typeConverter); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -285,6 +285,15 @@ return } +// CHECK-LABEL: @boolcmpi +func @boolcmpi(%arg0 : i1, %arg1 : i1) { + // CHECK: spv.LogicalEqual + %0 = cmpi "eq", %arg0, %arg1 : i1 + // CHECK: spv.LogicalNotEqual + %1 = cmpi "ne", %arg0, %arg1 : i1 + return +} + } // end module // -----