diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -549,14 +549,17 @@ static const Capability caps[] = {Capability::cap8}; \ ArrayRef ref(caps, llvm::array_lengthof(caps)); \ capabilities.push_back(ref); \ - } else if (bitwidth == 16) { \ + return; \ + } \ + if (bitwidth == 16) { \ static const Capability caps[] = {Capability::cap16}; \ ArrayRef ref(caps, llvm::array_lengthof(caps)); \ capabilities.push_back(ref); \ + return; \ } \ - /* No requirements for other bitwidths */ \ - return; \ - } + /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \ + /* storage classes. Fall through to the next section. */ \ + } break // This part only handles the cases where special bitwidths appearing in // interface storage classes. @@ -573,8 +576,9 @@ static const Capability caps[] = {Capability::StorageInputOutput16}; ArrayRef ref(caps, llvm::array_lengthof(caps)); capabilities.push_back(ref); + return; } - return; + break; } default: break; @@ -594,22 +598,22 @@ if (auto intType = dyn_cast()) { switch (bitwidth) { - case 32: - case 1: - break; WIDTH_CASE(Int, 8); WIDTH_CASE(Int, 16); WIDTH_CASE(Int, 64); + case 1: + case 32: + break; default: llvm_unreachable("invalid bitwidth to getCapabilities"); } } else { assert(isa()); switch (bitwidth) { - case 32: - break; WIDTH_CASE(Float, 16); WIDTH_CASE(Float, 64); + case 32: + break; default: llvm_unreachable("invalid bitwidth to getCapabilities"); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -413,7 +413,7 @@ } int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; - auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; + auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); @@ -455,13 +455,6 @@ if (!arrayElemType) return nullptr; - Optional elementSize = getTypeNumBytes(options, elementType); - if (!elementSize) { - LLVM_DEBUG(llvm::dbgs() - << type << " illegal: cannot deduce element size\n"); - return nullptr; - } - Optional arrayElemSize = getTypeNumBytes(options, arrayElemType); if (!arrayElemSize) { LLVM_DEBUG(llvm::dbgs() @@ -482,7 +475,7 @@ return nullptr; } - auto arrayElemCount = *memrefSize / *elementSize; + auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -312,53 +312,83 @@ func.func @memref_1bit_type(%arg0: memref<5xi1>) { return } // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_8bit_StorageBuffer // NOEMU-SAME: memref<16xi8> func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0])>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func @memref_8bit_Uniform // NOEMU-SAME: memref<16xsi8, 4> func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return } // CHECK-LABEL: spv.func @memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_8bit_PushConstant // NOEMU-SAME: memref<16xui8, 7> func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_16bit_StorageBuffer // NOEMU-SAME: memref<16xi16> func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0])>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func @memref_16bit_Uniform // NOEMU-SAME: memref<16xsi16, 4> func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_16bit_PushConstant // NOEMU-SAME: memref<16xui16, 7> func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr)>, Input> +// CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: func @memref_16bit_Input // NOEMU-SAME: memref<16xf16, 9> func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr)>, Output> +// CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: func @memref_16bit_Output // NOEMU-SAME: memref<16xf16, 10> func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } +// CHECK-LABEL: spv.func @memref_64bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// NOEMU-LABEL: func @memref_64bit_StorageBuffer +// NOEMU-SAME: memref<16xi64> +func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, 0>) { return } + +// CHECK-LABEL: spv.func @memref_64bit_Uniform +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// NOEMU-LABEL: func @memref_64bit_Uniform +// NOEMU-SAME: memref<16xsi64, 4> +func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, 4>) { return } + +// CHECK-LABEL: spv.func @memref_64bit_PushConstant +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// NOEMU-LABEL: func @memref_64bit_PushConstant +// NOEMU-SAME: memref<16xui64, 7> +func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, 7>) { return } + +// CHECK-LABEL: spv.func @memref_64bit_Input +// CHECK-SAME: !spv.ptr)>, Input> +// NOEMU-LABEL: func @memref_64bit_Input +// NOEMU-SAME: memref<16xf64, 9> +func.func @memref_64bit_Input(%arg3: memref<16xf64, 9>) { return } + +// CHECK-LABEL: spv.func @memref_64bit_Output +// CHECK-SAME: !spv.ptr)>, Output> +// NOEMU-LABEL: func @memref_64bit_Output +// NOEMU-SAME: memref<16xf64, 10> +func.func @memref_64bit_Output(%arg4: memref<16xf64, 10>) { return } + } // end module // ----- @@ -368,7 +398,7 @@ // and extension is available. module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> } { @@ -389,6 +419,17 @@ %arg1: memref<16xf16, 7> ) { return } +// CHECK-LABEL: spv.func @memref_64bit_PushConstant +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// NOEMU-LABEL: spv.func @memref_64bit_PushConstant +// NOEMU-SAME: !spv.ptr [0])>, PushConstant> +// NOEMU-SAME: !spv.ptr [0])>, PushConstant> +func.func @memref_64bit_PushConstant( + %arg0: memref<16xi64, 7>, + %arg1: memref<16xf64, 7> +) { return } + } // end module // ----- @@ -398,7 +439,7 @@ // and extension is available. module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> } { @@ -419,6 +460,17 @@ %arg1: memref<16xf16, 0> ) { return } +// CHECK-LABEL: spv.func @memref_64bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// NOEMU-LABEL: spv.func @memref_64bit_StorageBuffer +// NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> +// NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> +func.func @memref_64bit_StorageBuffer( + %arg0: memref<16xi64, 0>, + %arg1: memref<16xf64, 0> +) { return } + } // end module // ----- @@ -428,7 +480,7 @@ // and extension is available. module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> } { @@ -449,6 +501,17 @@ %arg1: memref<16xf16, 4> ) { return } +// CHECK-LABEL: spv.func @memref_64bit_Uniform +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// NOEMU-LABEL: spv.func @memref_64bit_Uniform +// NOEMU-SAME: !spv.ptr [0])>, Uniform> +// NOEMU-SAME: !spv.ptr [0])>, Uniform> +func.func @memref_64bit_Uniform( + %arg0: memref<16xi64, 4>, + %arg1: memref<16xf64, 4> +) { return } + } // end module // ----- @@ -458,7 +521,7 @@ // and extension is available. module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> + #spv.vce, {}> } { // CHECK-LABEL: spv.func @memref_16bit_Input @@ -473,6 +536,28 @@ // NOEMU-SAME: !spv.ptr)>, Output> func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } +// CHECK-LABEL: spv.func @memref_64bit_Input +// CHECK-SAME: !spv.ptr)>, Input> +// CHECK-SAME: !spv.ptr)>, Input> +// NOEMU-LABEL: spv.func @memref_64bit_Input +// NOEMU-SAME: !spv.ptr)>, Input> +// NOEMU-SAME: !spv.ptr)>, Input> +func.func @memref_64bit_Input( + %arg0: memref<16xi64, 9>, + %arg1: memref<16xf64, 9> +) { return } + +// CHECK-LABEL: spv.func @memref_64bit_Output +// CHECK-SAME: !spv.ptr)>, Output> +// CHECK-SAME: !spv.ptr)>, Output> +// NOEMU-LABEL: spv.func @memref_64bit_Output +// NOEMU-SAME: !spv.ptr)>, Output> +// NOEMU-SAME: !spv.ptr)>, Output> +func.func @memref_64bit_Output( + %arg0: memref<16xi64, 10>, + %arg1: memref<16xf64, 10> +) { return } + } // end module // ----- diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir @@ -40,7 +40,7 @@ } // CHECK: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr)>, Workgroup> +// CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK: func @alloc_dealloc_workgroup_mem // CHECK: %[[VAR:.+]] = spv.mlir.addressof @__workgroup_mem__0 // CHECK: %[[LOC:.+]] = spv.SDiv diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -186,10 +186,10 @@ // Complicated nested types // * Buffer requires ImageBuffer or SampledBuffer. // * Rg32f requires StorageImageExtendedFormats. -// CHECK: requires #spv.vce +// CHECK: requires #spv.vce spv.module Logical GLSL450 attributes { spv.target_env = #spv.target_env< - #spv.vce, + #spv.vce, {}> } { spv.GlobalVariable @data : !spv.ptr, Uniform>