diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -26,10 +26,19 @@ // Type Converter //===----------------------------------------------------------------------===// +/// How sub-byte values are stored in memory. +enum class SPIRVSubByteTypeStorage { + /// Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8. + Packing, +}; + struct SPIRVConversionOptions { /// The number of bits to store a boolean value. unsigned boolNumBits{8}; + /// How sub-byte values are stored in memory. + SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packing}; + /// Whether to emulate narrower scalar types with 32-bit scalar types if not /// supported by the target. /// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -20,6 +20,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include #include @@ -256,6 +257,31 @@ intType.getSignedness()); } +/// Converts a sub-byte integer `type` to i32 regardless of target environment. +/// +/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use +/// the above given that these sub-byte types are not supported at all in +/// SPIR-V; there are no compute/storage capabilities for them like other +/// supported integer types. 
+static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, + IntegerType type) { + if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packing) { + LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n"); + return nullptr; + } + + if (!llvm::isPowerOf2_32(type.getWidth())) { + LLVM_DEBUG(llvm::dbgs() + << "unsupported non-power-of-two bitwidth in sub-byte type: " << type + << "\n"); + return nullptr; + } + + LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); + return IntegerType::get(type.getContext(), /*width=*/32, + type.getSignedness()); +} + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. @@ -417,7 +443,41 @@ return wrapInStructAndGetPointer(arrayType, storageClass); } - int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; + int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8); + auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize); + int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; + auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); + if (targetEnv.allows(spirv::Capability::Kernel)) + return spirv::PointerType::get(arrayType, storageClass); + return wrapInStructAndGetPointer(arrayType, storageClass); +} + +static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, + const SPIRVConversionOptions &options, + MemRefType type, + spirv::StorageClass storageClass) { + IntegerType elementType = cast<IntegerType>(type.getElementType()); + Type arrayElemType = convertSubByteIntegerType(options, elementType); + if (!arrayElemType) + return nullptr; + std::optional<int64_t> arrayElemSize = + getTypeNumBytes(options, arrayElemType); + assert(arrayElemSize); + + if (!type.hasStaticShape()) { + // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing + // to the element. 
+ if (targetEnv.allows(spirv::Capability::Kernel)) + return spirv::PointerType::get(arrayElemType, storageClass); + int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; + auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); + // For Vulkan we need extra wrapping struct and array to satisfy interface + // needs. + return wrapInStructAndGetPointer(arrayType, storageClass); + } + + int64_t memrefSize = + llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8); auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); @@ -441,9 +501,11 @@ } spirv::StorageClass storageClass = attr.getValue(); - if (type.getElementType().isa<IntegerType>() && - type.getElementTypeBitWidth() == 1) { - return convertBoolMemrefType(targetEnv, options, type, storageClass); + if (type.getElementType().isa<IntegerType>()) { + if (type.getElementTypeBitWidth() == 1) + return convertBoolMemrefType(targetEnv, options, type, storageClass); + if (type.getElementTypeBitWidth() < 8) + return convertSubByteMemrefType(targetEnv, options, type, storageClass); } Type arrayElemType; @@ -514,10 +576,10 @@ // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) // were tried before. // - // TODO: this assumes that the SPIR-V types are valid to use in - // the given target environment, which should be the case if the whole - // pipeline is driven by the same target environment. Still, we probably still - // want to validate and convert to be safe. + // TODO: This assumes that the SPIR-V types are valid to use in the given + // target environment, which should be the case if the whole pipeline is + // driven by the same target environment. Still, we probably want to + // validate and convert to be safe. 
addConversion([](spirv::SPIRVType type) { return type; }); addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); @@ -525,6 +587,8 @@ addConversion([this](IntegerType intType) -> std::optional { if (auto scalarType = intType.dyn_cast()) return convertScalarType(this->targetEnv, this->options, scalarType); + if (intType.getWidth() < 8) + return convertSubByteIntegerType(this->options, intType); return Type(); }); diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -94,14 +94,32 @@ // ----- -// Check that weird bitwidths are not supported. +// Check that power-of-two sub-byte bitwidths are converted to i32. module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { -// CHECK-NOT: spirv.func @integer4 +// CHECK: spirv.func @integer2(%{{.+}}: i32) +func.func @integer2(%arg0: i8) { return } + +// CHECK: spirv.func @integer4(%{{.+}}: i32) func.func @integer4(%arg0: i4) { return } +} // end module + +// ----- + +// Check that other bitwidths are not supported. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-NOT: spirv.func @integer3 +func.func @integer3(%arg0: i3) { return } + +// CHECK-NOT: spirv.func @integer13 +func.func @integer4(%arg0: i13) { return } + // CHECK-NOT: spirv.func @integer128 func.func @integer128(%arg0: i128) { return } @@ -109,6 +127,7 @@ func.func @integer42(%arg0: i42) { return } } // end module + // ----- //===----------------------------------------------------------------------===// @@ -421,6 +440,16 @@ // NOEMU-SAME: memref<5xi1, #spirv.storage_class> func.func @memref_1bit_type(%arg0: memref<5xi1, #spirv.storage_class>) { return } +// 16 i2 values are tightly packed into one i32 value; so 33 i2 values takes 3 i32 value. 
+// CHECK-LABEL: spirv.func @memref_2bit_type +// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<3 x i32, stride=4> [0])>, StorageBuffer> +func.func @memref_2bit_type(%arg0: memref<33xi2, #spirv.storage_class<StorageBuffer>>) { return } + +// 8 i4 values are tightly packed into one i32 value; so 16 i4 values take 2 i32 values. +// CHECK-LABEL: spirv.func @memref_4bit_type +// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer> +func.func @memref_4bit_type(%arg0: memref<16xi4, #spirv.storage_class<StorageBuffer>>) { return } + // CHECK-LABEL: spirv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_8bit_StorageBuffer @@ -725,6 +754,14 @@ // NOEMU-SAME: memref> func.func @memref_1bit_type(%arg0: memref>) { return } +// CHECK-LABEL: spirv.func @memref_2bit_type +// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer> +func.func @memref_2bit_type(%arg0: memref<?xi2, #spirv.storage_class<StorageBuffer>>) { return } + +// CHECK-LABEL: spirv.func @memref_4bit_type +// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer> +func.func @memref_4bit_type(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>) { return } + // CHECK-LABEL: func @dynamic_dim_memref // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> // CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -419,3 +419,63 @@ } } // end module + +// ----- + +// Check that access chain indices are properly adjusted if sub-byte types are +// emulated via 32-bit types. 
+module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: @load_i4 +func.func @load_i4(%arg0: memref>, %i: index) -> i4 { + // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32 + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 + // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32 + // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 + // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 + // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32 + // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 + // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32 + // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32 + // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32 + // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32 + // CHECK: spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32 + %0 = memref.load %arg0[%i] : memref> + return %0 : i4 +} + +// CHECK-LABEL: @store_i4 +func.func @store_i4(%arg0: memref>, %value: i4, %i: index) { + // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32 + // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32 + // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 + // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32 + // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 + // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 + // CHECK: %[[FOUR:.+]] = spirv.Constant 
[[OFFSET]] : i32 + // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32 + // CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32 + // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32 + // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]] + // CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK2]] + // CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[%i] : memref> + return +} + +} // end module