diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -252,12 +252,18 @@ Specification | Dialect :----------------------------------: | :-------------------------------: `OpTypeBool` | `i1` -`OpTypeInt <bitwidth>` | `i<bitwidth>` `OpTypeFloat <bitwidth>` | `f<bitwidth>` `OpTypeVector <scalar-type> <count>` | `vector<<count> x <scalar-type>>` -Similarly, `mlir::NoneType` can be used for SPIR-V `OpTypeVoid`; builtin -function types can be used for SPIR-V `OpTypeFunction` types. +For integer types, the SPIR-V dialect supports all signedness semantics +(signless, signed, unsigned) in order to ease transformations from higher level +dialects. However, the SPIR-V spec only defines two signedness semantics states: 0 +indicates unsigned, or no signedness semantics; 1 indicates signed semantics. So +both `iN` and `uiN` are serialized into the same `OpTypeInt N 0`. For +deserialization, we always treat `OpTypeInt N 0` as `iN`. + +`mlir::NoneType` is used for SPIR-V `OpTypeVoid`; builtin function types are +used for SPIR-V `OpTypeFunction` types.
The SPIR-V dialect and defines the following dialect-specific types: diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2945,6 +2945,17 @@ // SPIR-V type definitions //===----------------------------------------------------------------------===// +class IOrUI + : Type, + CPred<"$_self.isUnsignedInteger(" # width # ")">]>, + width # "-bit signless/unsigned integer"> { + int bitwidth = width; +} + +class SignlessOrUnsignedIntOfWidths widths> : + AnyTypeOf), + StrJoinInt.result # "-bit signless/unsigned integer">; + def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; @@ -2953,8 +2964,8 @@ // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types // for the definition of the following types and type categories. 
-def SPV_Void : TypeAlias; -def SPV_Bool : I<1>; +def SPV_Void : TypeAlias; +def SPV_Bool : TypeAlias; def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>; def SPV_Float : FloatOfWidths<[16, 32, 64]>; def SPV_Float16or32 : FloatOfWidths<[16, 32]>; @@ -2977,6 +2988,8 @@ SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct ]>; +def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; + class SPV_ScalarOrVectorOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>; @@ -2985,7 +2998,8 @@ class SPV_Vec4 : VectorOfLengthAndType<[4], [type]>; def SPV_IntVec4 : SPV_Vec4; -def SPV_I32Vec4 : SPV_Vec4; +def SPV_IOrUIVec4 : SPV_Vec4; +def SPV_Int32Vec4 : SPV_Vec4; // TODO(antiagainst): Use a more appropriate way to model optional operands class SPV_Optional : Variadic; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -61,7 +61,7 @@ ); let results = (outs - SPV_I32Vec4:$result + SPV_Int32Vec4:$result ); let verifier = [{ return success(); }]; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -95,7 +95,7 @@ ); let results = (outs - SPV_IntVec4:$result + SPV_IOrUIVec4:$result ); let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -363,11 +363,12 @@ // TODO: Make sure not caller relies on the actual pointer width value. 
return 64; } - if (type.isSignlessIntOrFloat()) { + + if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); - } + if (auto vectorType = type.dyn_cast()) { - assert(vectorType.getElementType().isSignlessIntOrFloat()); + assert(vectorType.getElementType().isIntOrFloat()); return vectorType.getNumElements() * vectorType.getElementType().getIntOrFloatBitWidth(); } @@ -500,7 +501,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) { auto ptrType = op->getOperand(0).getType().cast(); auto elementType = ptrType.getPointeeType(); - if (!elementType.isSignlessInteger()) + if (!elementType.isa()) return op->emitOpError( "pointer operand must point to an integer value, found ") << elementType; @@ -1265,7 +1266,7 @@ numElements *= t.getNumElements(); opElemType = t.getElementType(); } - if (!opElemType.isSignlessIntOrFloat()) { + if (!opElemType.isIntOrFloat()) { return constOp.emitOpError("only support nested array result type"); } @@ -1769,8 +1770,6 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { - // TODO(antiagainst): check the result integer type's signedness bit is 0. - spirv::Scope scope = ballotOp.execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return ballotOp.emitOpError( diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -344,9 +344,6 @@ /// insertion point. LogicalResult processUndef(ArrayRef operands); - /// Processes an OpBitcast instruction. - LogicalResult processBitcast(ArrayRef words); - /// Method to dispatch to the specialized deserialization function for an /// operation in SPIR-V dialect that is a mirror of an instruction in the /// SPIR-V spec. This is auto-generated from ODS. 
Dispatch is handled for @@ -1045,30 +1042,35 @@ switch (opcode) { case spirv::Opcode::OpTypeVoid: - if (operands.size() != 1) { + if (operands.size() != 1) return emitError(unknownLoc, "OpTypeVoid must have no parameters"); - } typeMap[operands[0]] = opBuilder.getNoneType(); break; case spirv::Opcode::OpTypeBool: - if (operands.size() != 1) { + if (operands.size() != 1) return emitError(unknownLoc, "OpTypeBool must have no parameters"); - } typeMap[operands[0]] = opBuilder.getI1Type(); break; - case spirv::Opcode::OpTypeInt: - if (operands.size() != 3) { + case spirv::Opcode::OpTypeInt: { + if (operands.size() != 3) return emitError( unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); - } - // TODO: Ignoring the signedness right now. Need to handle this effectively - // in the MLIR representation. - typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]); - break; + + // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics + // to preserve or validate. + // 0 indicates unsigned, or no signedness semantics + // 1 indicates signed semantics." + // + // So we cannot differentiate signless and unsigned integers; always use + // signless semantics for such cases. + auto sign = operands[2] == 1 ? 
IntegerType::SignednessSemantics::Signed + : IntegerType::SignednessSemantics::Signless; + typeMap[operands[0]] = IntegerType::get(operands[1], sign, context); + } break; case spirv::Opcode::OpTypeFloat: { - if (operands.size() != 2) { + if (operands.size() != 2) return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); - } + Type floatTy; switch (operands[1]) { case 16: @@ -1146,7 +1148,7 @@ } if (auto intVal = countInfo->first.dyn_cast()) { - count = intVal.getInt(); + count = intVal.getValue().getZExtValue(); } else { return emitError(unknownLoc, "OpTypeArray count must come from a " "scalar integer constant instruction"); @@ -1451,8 +1453,7 @@ } auto resultID = operands[1]; - if (resultType.isSignlessInteger() || resultType.isa() || - resultType.isa()) { + if (resultType.isIntOrFloat() || resultType.isa()) { auto attr = opBuilder.getZeroAttr(resultType); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. @@ -2051,8 +2052,6 @@ // First dispatch all the instructions whose opcode does not correspond to // those that have a direct mirror in the SPIR-V dialect switch (opcode) { - case spirv::Opcode::OpBitcast: - return processBitcast(operands); case spirv::Opcode::OpCapability: return processCapability(operands); case spirv::Opcode::OpExtension: @@ -2152,76 +2151,6 @@ return success(); } -// TODO(b/130356985): This method is copied from the auto-generated -// deserialization function for OpBitcast instruction. This is to avoid -// generating a Bitcast operations for cast from signed integer to unsigned -// integer and viceversa. MLIR doesn't have native support for this so they both -// end up mapping to the same type right now which is illegal according to -// OpBitcast semantics (and enforced by the SPIR-V dialect). 
-LogicalResult Deserializer::processBitcast(ArrayRef words) { - SmallVector resultTypes; - size_t wordIndex = 0; - (void)wordIndex; - uint32_t valueID = 0; - (void)valueID; - { - if (wordIndex >= words.size()) { - return emitError( - unknownLoc, - "expected result type while deserializing spirv::BitcastOp"); - } - auto ty = getType(words[wordIndex]); - if (!ty) { - return emitError(unknownLoc, "unknown type result : ") - << words[wordIndex]; - } - resultTypes.push_back(ty); - wordIndex++; - if (wordIndex >= words.size()) { - return emitError( - unknownLoc, - "expected result while deserializing spirv::BitcastOp"); - } - } - valueID = words[wordIndex++]; - SmallVector operands; - SmallVector attributes; - if (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); - if (!arg) { - return emitError(unknownLoc, "unknown result : ") - << words[wordIndex]; - } - operands.push_back(arg); - wordIndex++; - } - if (wordIndex != words.size()) { - return emitError(unknownLoc, - "found more operands than expected when deserializing " - "spirv::BitcastOp, only ") - << wordIndex << " of " << words.size() << " processed"; - } - if (resultTypes[0] == operands[0].getType() && - resultTypes[0].isSignlessInteger()) { - // TODO(b/130356985): This check is added to ignore error in Op verification - // due to both signed and unsigned integers mapping to the same - // type. Without this check this method is same as what is auto-generated. 
- valueMap[valueID] = operands[0]; - return success(); - } - - auto op = opBuilder.create(unknownLoc, resultTypes, - operands, attributes); - (void)op; - valueMap[valueID] = op.getResult(); - - if (decorations.count(valueID)) { - auto attrs = decorations[valueID].getAttrs(); - attributes.append(attrs.begin(), attrs.end()); - } - return success(); -} - LogicalResult Deserializer::processExtInst(ArrayRef operands) { if (operands.size() < 4) { return emitError(unknownLoc, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -932,8 +932,11 @@ typeEnum = spirv::Opcode::OpTypeInt; operands.push_back(intType.getWidth()); - // TODO(antiagainst): support unsigned integers - operands.push_back(1); + // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics + // to preserve or validate. + // 0 indicates unsigned, or no signedness semantics + // 1 indicates signed semantics." + operands.push_back(intType.isSigned() ? 
1 : 0); return success(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir @@ -4,6 +4,10 @@ spv.func @bit_cast(%arg0 : f32) "None" { // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : f32 to i32 %0 = spv.Bitcast %arg0 : f32 to i32 + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : i32 to si32 + %1 = spv.Bitcast %0 : i32 to si32 + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : si32 to i32 + %2 = spv.Bitcast %1 : si32 to ui32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir @@ -27,6 +27,37 @@ spv.Return } + // CHECK-LABEL: @si32_const + spv.func @si32_const() -> () "None" { + // CHECK: spv.constant 0 : si32 + %0 = spv.constant 0 : si32 + // CHECK: spv.constant 10 : si32 + %1 = spv.constant 10 : si32 + // CHECK: spv.constant -5 : si32 + %2 = spv.constant -5 : si32 + + %3 = spv.IAdd %0, %1 : si32 + %4 = spv.IAdd %2, %3 : si32 + spv.Return + } + + // CHECK-LABEL: @ui32_const + // We cannot differentiate signless vs. unsigned integers in SPIR-V blob + // because they all use 0 as the signedness bit. So we always treat them + // as signless integers.
+ spv.func @ui32_const() -> () "None" { + // CHECK: spv.constant 0 : i32 + %0 = spv.constant 0 : ui32 + // CHECK: spv.constant 10 : i32 + %1 = spv.constant 10 : ui32 + // CHECK: spv.constant -5 : i32 + %2 = spv.constant 4294967291 : ui32 + + %3 = spv.IAdd %0, %1 : ui32 + %4 = spv.IAdd %2, %3 : ui32 + spv.Return + } + // CHECK-LABEL: @i64_const spv.func @i64_const() -> () "None" { // CHECK: spv.constant 4294967296 : i64 @@ -141,8 +172,23 @@ spv.Return } - // CHECK-LABEL: @array_const - spv.func @array_const() -> (!spv.array<2 x vector<2xf32>>) "None" { + // CHECK-LABEL: @ui64_array_const + spv.func @ui64_array_const() -> (!spv.array<3xui64>) "None" { + // CHECK: spv.constant [5, 6, 7] : !spv.array<3 x i64> + %0 = spv.constant [5 : ui64, 6 : ui64, 7 : ui64] : !spv.array<3 x ui64> + + spv.ReturnValue %0: !spv.array<3xui64> + } + + // CHECK-LABEL: @si32_array_const + spv.func @si32_array_const() -> (!spv.array<3xsi32>) "None" { + // CHECK: spv.constant [5 : si32, 6 : si32, 7 : si32] : !spv.array<3 x si32> + %0 = spv.constant [5 : si32, 6 : si32, 7 : si32] : !spv.array<3 x si32> + + spv.ReturnValue %0 : !spv.array<3xsi32> + } + // CHECK-LABEL: @float_array_const + spv.func @float_array_const() -> (!spv.array<2 x vector<2xf32>>) "None" { // CHECK: spv.constant [dense<3.000000e+00> : vector<2xf32>, dense<[4.000000e+00, 5.000000e+00]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>> %0 = spv.constant [dense<3.0> : vector<2xf32>, dense<[4., 5.]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>> diff --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir --- a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir @@ -20,6 +20,14 @@ // ----- +func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> { + // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}} + %0 = spv.GroupNonUniformBallot 
"Workgroup" %predicate : vector<4xsi32> + return %0: vector<4xsi32> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.GroupNonUniformElect //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -752,7 +752,7 @@ func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{operand #0 must be 1-bit signless integer or vector of 1-bit signless integer values of length 2/3/4, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4, but got 'i32'}} %0 = spv.LogicalNot %arg0 : i32 return }