diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -845,7 +845,8 @@ #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ - srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \ + !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \ + !hasSameBitwidth(srcType, dstType)) { \ return op.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -10,6 +10,7 @@ #define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" @@ -34,9 +35,10 @@ } if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && - !op.getType().isIndex() && dstType != op.getType()) { - return op.emitError( - "bitwidth emulation is not implemented yet on unsigned op"); + !getElementTypeOrSelf(op.getType()).isIndex() && + dstType != op.getType()) { + return op.emitError("bitwidth emulation is not implemented yet on " + "unsigned op pattern version"); } rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands()); diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -179,6 +179,13 @@ return } +// CHECK-LABEL: @index_vector +func.func @index_vector(%arg0: vector<4xindex>) { + // CHECK: spirv.UMod %{{.*}}, %{{.*}}: vector<4xi32> + %0 = arith.remui %arg0, %arg0: vector<4xindex> + return +} + // 
CHECK-LABEL: @vector_srem // CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>) func.func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) { @@ -417,6 +424,31 @@ return } +// CHECK-LABEL: @indexcmpi +func.func @indexcmpi(%arg0 : index, %arg1 : index) { + // CHECK: spirv.IEqual + %0 = arith.cmpi eq, %arg0, %arg1 : index + // CHECK: spirv.INotEqual + %1 = arith.cmpi ne, %arg0, %arg1 : index + // CHECK: spirv.SLessThan + %2 = arith.cmpi slt, %arg0, %arg1 : index + // CHECK: spirv.SLessThanEqual + %3 = arith.cmpi sle, %arg0, %arg1 : index + // CHECK: spirv.SGreaterThan + %4 = arith.cmpi sgt, %arg0, %arg1 : index + // CHECK: spirv.SGreaterThanEqual + %5 = arith.cmpi sge, %arg0, %arg1 : index + // CHECK: spirv.ULessThan + %6 = arith.cmpi ult, %arg0, %arg1 : index + // CHECK: spirv.ULessThanEqual + %7 = arith.cmpi ule, %arg0, %arg1 : index + // CHECK: spirv.UGreaterThan + %8 = arith.cmpi ugt, %arg0, %arg1 : index + // CHECK: spirv.UGreaterThanEqual + %9 = arith.cmpi uge, %arg0, %arg1 : index + return +} + // CHECK-LABEL: @vec1cmpi func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) { // CHECK: spirv.ULessThan