NOTE(review): recovered patch. The copy this came from had its line breaks
collapsed and, apparently, every `<...>` span that looked like markup removed.
The C++ template arguments (`dyn_cast<VectorType>`, `dyn_cast<MemRefType>`,
`dyn_cast<spirv::ScalarType>`, `Optional<Type>`, `Optional<int64_t>`) follow
from the surrounding code; the `#spv.vce<...>` triples, the `!spv.ptr<...>`
CHECK types and the `%arg1` type of `@dynamic_dim_memref_vector` are
reconstructions -- confirm them against the upstream revision before landing.
NOTE(review): for 3-element vectors the new `getTypeNumBytes` branch returns
3 * 4 = 12 bytes, which `convertMemrefType` then uses as the array stride;
confirm that matches the intended buffer layout (std140/std430 pad vec3 array
elements to 16 bytes).
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -217,11 +217,15 @@
 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
 static bool isAllocationSupported(MemRefType t) {
   // Currently only support workgroup local memory allocations with static
-  // shape and int or float element type.
-  return t.hasStaticShape() &&
-         SPIRVTypeConverter::getMemorySpaceForStorageClass(
-             spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
-         t.getElementType().isIntOrFloat();
+  // shape and int or float or vector of int or float element type.
+  if (!(t.hasStaticShape() &&
+        SPIRVTypeConverter::getMemorySpaceForStorageClass(
+            spirv::StorageClass::Workgroup) == t.getMemorySpace()))
+    return false;
+  Type elementType = t.getElementType();
+  if (auto vecType = elementType.dyn_cast<VectorType>())
+    elementType = vecType.getElementType();
+  return elementType.isIntOrFloat();
 }
 
 /// Returns the scope to use for atomic operations use for emulating store
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -170,7 +170,14 @@
       return llvm::None;
     }
     return bitWidth / 8;
-  } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
+  }
+  if (auto vecType = t.dyn_cast<VectorType>()) {
+    auto elementSize = getTypeNumBytes(vecType.getElementType());
+    if (!elementSize)
+      return llvm::None;
+    return vecType.getNumElements() * *elementSize;
+  }
+  if (auto memRefType = t.dyn_cast<MemRefType>()) {
     // TODO: Layout should also be controlled by the ABI attributes. For now
     // using the layout from MemRef.
     int64_t offset;
@@ -343,26 +350,31 @@
     return llvm::None;
   }
 
-  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
-  if (!scalarType) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert non-scalar element type\n");
+  Optional<Type> arrayElemType;
+  Type elementType = type.getElementType();
+  if (auto vecType = elementType.dyn_cast<VectorType>()) {
+    arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
+  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+    arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
+  } else {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << type
+        << " unhandled: can only convert scalar or vector element type\n");
     return llvm::None;
   }
-
-  auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
   if (!arrayElemType)
     return llvm::None;
 
-  Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
-  if (!scalarSize) {
+  Optional<int64_t> elementSize = getTypeNumBytes(elementType);
+  if (!elementSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce element size\n");
     return llvm::None;
   }
 
   if (!type.hasStaticShape()) {
-    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
+    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
     // Wrap in a struct to satisfy Vulkan interface requirements.
     auto structType = spirv::StructType::get(arrayType, 0);
     return spirv::PointerType::get(structType, *storageClass);
@@ -375,7 +387,7 @@
     return llvm::None;
   }
 
-  auto arrayElemCount = *memrefSize / *scalarSize;
+  auto arrayElemCount = *memrefSize / *elementSize;
 
   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
   if (!arrayElemSize) {
diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -75,6 +75,30 @@
 // CHECK: spv.func @two_allocs()
 // CHECK: spv.Return
 
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @two_allocs_vector() {
+    %0 = alloc() : memref<4xvector<4xf32>, 3>
+    %1 = alloc() : memref<2xvector<2xi32>, 3>
+    return
+  }
+}
+
+// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<2 x vector<2xi32>, stride=8>>, Workgroup>
+// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16>>, Workgroup>
+// CHECK: spv.func @two_allocs_vector()
+// CHECK: spv.Return
+
+
 // -----
 
 module attributes {
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -510,6 +510,51 @@
 
 // -----
 
+// Vector types
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_vector
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<2xf32>, stride=8> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16> [0]>, Uniform>
+func @memref_vector(
+    %arg0: memref<4xvector<2xf32>, 0>,
+    %arg1: memref<4xvector<4xf32>, 4>)
+{ return }
+
+// CHECK-LABEL: func @dynamic_dim_memref_vector
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<4xi32>, stride=16> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<2xf32>, stride=8> [0]>, StorageBuffer>
+func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
+                                %arg1: memref<?xvector<2xf32>>)
+{ return }
+
+} // end module
+
+// -----
+
+// Vector types, check that sizes not available in SPIR-V are not transformed.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_vector_wrong_size
+// CHECK-SAME: memref<4xvector<5xf32>>
+func @memref_vector_wrong_size(
+    %arg0: memref<4xvector<5xf32>, 0>)
+{ return }
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//