[mlir][spirv] Allow unsigned elementwise ops on index-typed arith ops

ElementwiseOpPattern rejects unsigned SPIR-V ops whenever the converted
result type differs from the source type, because bitwidth emulation is
not implemented for them. The new tests expect index operands to be
converted to i32, so for index results dstType != op.getType() always
holds and unsigned lowerings (e.g. arith.divui -> spv.UDiv,
arith.remui -> spv.UMod) would always hit that error. This patch skips
the check when the source op's result type is index.

Tests:
  * rename the existing i32 @scalar_srem test to @int32_scalar_srem;
  * add @index_scalar (addi/subi/muli/divsi/divui/remui on index) and
    @index_scalar_srem (remsi on index), checking the SPIR-V ops are
    emitted on i32.
NOTE(review): the i32 expectation assumes the test file's RUN line uses
a type converter that maps index to i32 - not visible in this patch.

diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -30,7 +30,7 @@ if (!dstType) return failure(); if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && - dstType != op.getType()) { + !op.getType().isIndex() && dstType != op.getType()) { return op.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -27,9 +27,9 @@ return } -// CHECK-LABEL: @scalar_srem +// CHECK-LABEL: @int32_scalar_srem // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) -func.func @scalar_srem(%lhs: i32, %rhs: i32) { +func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) { // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32 // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32 // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32 @@ -40,6 +40,36 @@ return } +// CHECK-LABEL: @index_scalar +func.func @index_scalar(%lhs: index, %rhs: index) { + // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32 + %0 = arith.addi %lhs, %rhs: index + // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32 + %1 = arith.subi %lhs, %rhs: index + // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32 + %2 = arith.muli %lhs, %rhs: index + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32 + %3 = arith.divsi %lhs, %rhs: index + // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32 + %4 = arith.divui %lhs, %rhs: index + // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 + %5 = arith.remui %lhs, %rhs: index + return +} + +// CHECK-LABEL: @index_scalar_srem +// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) +func.func @index_scalar_srem(%lhs: index, %rhs: index) { + // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32 + // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs 
%[[RHS]] : i32 + // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32 + // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32 + // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32 + // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32 + %0 = arith.remsi %lhs, %rhs: index + return +} + // Check float unary operation conversions. // CHECK-LABEL: @float32_unary_scalar func.func @float32_unary_scalar(%arg0: f32) {