diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -331,10 +331,11 @@ static Optional convertMemrefType(const spirv::TargetEnv &targetEnv, MemRefType type) { - // TODO(ravishankarm) : Handle dynamic shapes. - if (!type.hasStaticShape()) { + auto storageClass = + SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); + if (!storageClass) { LLVM_DEBUG(llvm::dbgs() - << type << " illegal: dynamic shape unimplemented\n"); + << type << " illegal: cannot convert memory space\n"); return llvm::None; } @@ -345,27 +346,33 @@ return llvm::None; } + auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); + if (!arrayElemType) + return llvm::None; + Optional scalarSize = getTypeNumBytes(scalarType); - Optional memrefSize = getTypeNumBytes(type); - if (!scalarSize || !memrefSize) { + if (!scalarSize) { LLVM_DEBUG(llvm::dbgs() - << type << " illegal: cannot deduce element count\n"); + << type << " illegal: cannot deduce element size\n"); return llvm::None; } - auto arrayElemCount = *memrefSize / *scalarSize; + if (!type.hasStaticShape()) { + auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize); + // Wrap in a struct to satisfy Vulkan interface requirements. + auto structType = spirv::StructType::get(arrayType, 0); + return spirv::PointerType::get(structType, *storageClass); + } - auto storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); - if (!storageClass) { + Optional memrefSize = getTypeNumBytes(type); + if (!memrefSize) { LLVM_DEBUG(llvm::dbgs() - << type << " illegal: cannot convert memory space\n"); + << type << " illegal: cannot deduce element count\n"); return llvm::None; } - auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass); - if (!arrayElemType) - return llvm::None; + auto arrayElemCount = *memrefSize / *scalarSize; + Optional arrayElemSize = getTypeNumBytes(*arrayElemType); if (!arrayElemSize) { LLVM_DEBUG(llvm::dbgs() diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir rename from mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir rename to mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -486,7 +486,7 @@ // ----- -// Check that dynamic shapes are not supported. +// Dynamic shapes module attributes { spv.target_env = #spv.target_env< #spv.vce, @@ -494,13 +494,17 @@ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> } { +// Check that unranked shapes are not supported. // CHECK-LABEL: func @unranked_memref // CHECK-SAME: memref<*xi32> func @unranked_memref(%arg0: memref<*xi32>) { return } // CHECK-LABEL: func @dynamic_dim_memref -// CHECK-SAME: memref<8x?xi32> -func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return } +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @dynamic_dim_memref(%arg0: memref<8x?xi32>, + %arg1: memref) +{ return } } // end module