diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -126,9 +126,12 @@ return builder.create(loc, type, isPositive, abs, absNegate); } -/// Returns the offset of the value in `targetBits` representation. `srcIdx` is -/// an index into a 1-D array with each element having `sourceBits`. When -/// accessing an element in the array treating as having elements of +/// Returns the offset of the value in `targetBits` representation. +/// +/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. +/// It's assumed to be non-negative. +/// +/// When accessing an element in the array treated as having elements of /// `targetBits`, multiple values are loaded in the same time. The method /// returns the offset where the `srcIdx` locates in the value.
For example, if /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is @@ -144,7 +147,7 @@ IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); auto srcBitsValue = builder.create(loc, targetType, srcBitsAttr); - auto m = builder.create(loc, srcIdx, idx); + auto m = builder.create(loc, srcIdx, idx); return builder.create(loc, targetType, m, srcBitsValue); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -762,7 +762,7 @@ // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32 // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 @@ -788,7 +788,7 @@ // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32 // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32 - // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO2]] : i32 // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32 @@ -824,7 +824,7 @@ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 // CHECK: %[[FOUR:.+]] = spv.constant 4 : i32 // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], 
%[[EIGHT]] : i32 // CHECK: %[[MASK1:.+]] = spv.constant 255 : i32 // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 @@ -850,7 +850,7 @@ // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32 // CHECK: %[[TWO:.+]] = spv.constant 2 : i32 // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32 - // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO]] : i32 // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 // CHECK: %[[MASK1:.+]] = spv.constant 65535 : i32 // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 @@ -907,7 +907,7 @@ // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32 // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 @@ -934,7 +934,7 @@ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 // CHECK: %[[FOUR:.+]] = spv.constant 4 : i32 // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[MASK1:.+]] = spv.constant 255 : i32 // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32