diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -661,6 +661,16 @@ Value mask = rewriter.create( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); result = rewriter.create(loc, dstType, result, mask); + + // If the number is negative, we should do a sign extension. + IntegerAttr shiftValueAttr = + rewriter.getIntegerAttr(dstType, dstBits - srcBits); + Value shiftValue = + rewriter.create(loc, dstType, shiftValueAttr); + result = rewriter.create(loc, dstType, result, + shiftValue); + result = rewriter.create(loc, dstType, result, + shiftValue); rewriter.replaceOp(loadOp, result); assert(accessChainOp.use_empty()); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -689,7 +689,10 @@ // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 - // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 %0 = load %arg0[] : memref return } @@ -710,7 +713,10 @@ // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32 - // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = 
spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 16 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 %0 = load %arg0[%index] : memref<10xi16> return } @@ -824,7 +830,10 @@ // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 - // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 %0 = load %arg0[] : memref return }