diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -767,8 +767,7 @@ CmpIOpAdaptor cmpIOpOperands(operands); Type operandType = cmpIOp.lhs().getType(); - if (!operandType.isa() || - operandType.cast().getWidth() != 1) + if (!isBoolScalarOrVector(operandType)) return failure(); switch (cmpIOp.getPredicate()) { @@ -794,8 +793,7 @@ CmpIOpAdaptor cmpIOpOperands(operands); Type operandType = cmpIOp.lhs().getType(); - if (operandType.isa() && - operandType.cast().getWidth() == 1) + if (isBoolScalarOrVector(operandType)) return failure(); switch (cmpIOp.getPredicate()) { diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -327,6 +327,15 @@ return } +// CHECK-LABEL: @vecboolcmpi +func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { + // CHECK: spv.LogicalEqual + %0 = cmpi "eq", %arg0, %arg1 : vector<4xi1> + // CHECK: spv.LogicalNotEqual + %1 = cmpi "ne", %arg0, %arg1 : vector<4xi1> + return +} + } // end module // -----