diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -147,14 +147,35 @@
 }
 
 /// Returns the shifted `targetBits`-bit value with the given offset.
-Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                 int targetBits, OpBuilder &builder) {
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                        int targetBits, OpBuilder &builder) {
   Type targetType = builder.getIntegerType(targetBits);
   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
                                                    offset);
 }
 
+/// Returns true if the SPIR-V op `Op` interprets its operands as unsigned.
+template <typename Op>
+static bool isUnsignedOp() {
+  return false;
+}
+
+#define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
+  template <>                                                                  \
+  bool isUnsignedOp<SPIRVOp>() {                                               \
+    return true;                                                               \
+  }
+
+CHECK_UNSIGNED_OP(spirv::UDivOp);
+CHECK_UNSIGNED_OP(spirv::UModOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
+
+#undef CHECK_UNSIGNED_OP
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -178,6 +199,10 @@
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
+    if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+      return operation.emitError(
+          "bitwidth emulation is not implemented yet on unsigned op");
+    }
     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
                                                   ArrayRef<NamedAttribute>());
     return success();
@@ -581,6 +606,11 @@
   switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
+    if (isUnsignedOp<spirvOp>() &&                                             \
+        operandType != this->typeConverter.convertType(operandType)) {         \
+      return cmpIOp.emitError(                                                 \
+          "bitwidth emulation is not implemented yet on unsigned op");         \
+    }                                                                          \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                          cmpIOpOperands.lhs(),                 \
                                          cmpIOpOperands.rhs());                \
@@ -661,6 +691,18 @@
   Value mask = rewriter.create<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+
+  // Sign-extend the loaded value unconditionally. Signedness is carried by
+  // the consuming operations themselves; other patterns handle any casting
+  // (unsigned ops on emulated bitwidths are currently rejected).
+  IntegerAttr shiftAmountAttr =
+      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
+  Value shiftAmount =
+      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftAmountAttr);
+  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
+                                                      shiftAmount);
+  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
+                                                          shiftAmount);
   rewriter.replaceOp(loadOp, result);
 
   assert(accessChainOp.use_empty());
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -128,14 +128,12 @@
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
 } {
 
-// CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+// CHECK-LABEL: @int_vector23
+func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
   %0 = divi_signed %arg0, %arg0: vector<2xi8>
   // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
   %1 = remi_signed %arg1, %arg1: vector<3xi16>
-  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32>
-  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
   return
 }
 
@@ -148,6 +146,13 @@
   return
 }
 
+// CHECK-LABEL: @int_vector4_invalid
+func @int_vector4_invalid(%arg0: vector<4xi64>) {
+  // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
+  %0 = divi_unsigned %arg0, %arg0: vector<4xi64>
+  return
+}
+
 } // end module
 
 // -----
@@ -717,7 +722,10 @@
   // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }
@@ -738,7 +746,10 @@
   // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
-  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spv.constant 16 : i32
+  // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[%index] : memref<10xi16>
   return
 }
@@ -852,7 +863,10 @@
   // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }