diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -147,14 +147,44 @@
 }
 
 /// Returns the shifted `targetBits`-bit value with the given offset.
-Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                 int targetBits, OpBuilder &builder) {
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                        int targetBits, OpBuilder &builder) {
   Type targetType = builder.getIntegerType(targetBits);
   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
                                                    offset);
 }
 
+/// Returns true if the operator is operating on unsigned integers.
+/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
+/// to the ops themselves.
+template <typename SPIRVOp>
+bool isUnsignedOp() {
+  return false;
+}
+
+#define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
+  template <>                                                                  \
+  bool isUnsignedOp<SPIRVOp>() {                                               \
+    return true;                                                               \
+  }
+
+CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp);
+CHECK_UNSIGNED_OP(spirv::AtomicUMinOp);
+CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp);
+CHECK_UNSIGNED_OP(spirv::ConvertUToFOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp);
+CHECK_UNSIGNED_OP(spirv::UConvertOp);
+CHECK_UNSIGNED_OP(spirv::UDivOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanOp);
+CHECK_UNSIGNED_OP(spirv::UModOp);
+
+#undef CHECK_UNSIGNED_OP
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -178,6 +208,10 @@
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
+    if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+      return operation.emitError(
+          "bitwidth emulation is not implemented yet on unsigned op");
+    }
     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
                                                   ArrayRef<NamedAttribute>());
     return success();
@@ -581,6 +615,11 @@
   switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
+    if (isUnsignedOp<spirvOp>() &&                                             \
+        operandType != this->typeConverter.convertType(operandType)) {         \
+      return cmpIOp.emitError(                                                 \
+          "bitwidth emulation is not implemented yet on unsigned op");         \
+    }                                                                          \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                          cmpIOpOperands.lhs(),                 \
                                          cmpIOpOperands.rhs());                \
@@ -661,6 +700,18 @@
     Value mask = rewriter.create<spirv::ConstantOp>(
        loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
     result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+
+    // Apply sign extension on the loading value unconditionally. The
+    // signedness semantic is carried in the operator itself; we rely on other
+    // patterns to handle the casting.
+    IntegerAttr shiftValueAttr =
+        rewriter.getIntegerAttr(dstType, dstBits - srcBits);
+    Value shiftValue =
+        rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+    result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
+                                                        shiftValue);
+    result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType,
+                                                            result, shiftValue);
     rewriter.replaceOp(loadOp, result);
 
     assert(accessChainOp.use_empty());
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 //===----------------------------------------------------------------------===//
 // std arithmetic ops
 //===----------------------------------------------------------------------===//
@@ -128,14 +128,12 @@
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
 } {
 
-// CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+// CHECK-LABEL: @int_vector23
+func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
   %0 = divi_signed %arg0, %arg0: vector<2xi8>
   // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
   %1 = remi_signed %arg1, %arg1: vector<3xi16>
-  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32>
-  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
   return
 }
 
@@ -152,6 +150,27 @@
 
 // -----
 
+// Check that types are converted to 32-bit when special capabilities that
+// support them are not available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @int_vector4_invalid
+func @int_vector4_invalid(%arg0: vector<4xi64>) {
+  // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
+  // expected-error @+1 {{op requires the same type for all operands and results}}
+  %0 = divi_unsigned %arg0, %arg0: vector<4xi64>
+  return
+}
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std bit ops
 //===----------------------------------------------------------------------===//
@@ -717,7 +736,10 @@
   // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }
@@ -738,7 +760,10 @@
   // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
-  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spv.constant 16 : i32
+  // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[%index] : memref<10xi16>
   return
 }
@@ -852,7 +877,10 @@
   // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }