[mlir][spirv] Handle runtime arrays when lowering emulated-int load/store

IntLoadOpPattern and IntStoreOpPattern assumed the converted memref type is a
pointer to a struct whose first member is a spirv::ArrayType, and cast to that
type unconditionally. A memref with an unknown (dynamic) dimension converts to
a struct wrapping a spirv::RuntimeArrayType instead, so that cast would fail.
Accept either array kind when extracting the scalar element type, and add a
regression test that loads from and stores to memref<?xi32>.

diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -994,13 +994,17 @@
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
-  auto dstType = typeConverter.convertType(memrefType)
+  Type dstType = typeConverter.convertType(memrefType)
                      .cast<spirv::PointerType>()
-                     .getPointeeType()
-                     .cast<spirv::StructType>()
-                     .getElementType(0)
-                     .cast<spirv::ArrayType>()
-                     .getElementType();
+                     .getPointeeType();
+  // The pointee is a struct whose first member is the backing array: a
+  // spirv::ArrayType, or a spirv::RuntimeArrayType (e.g. unknown dimension).
+  dstType = dstType.cast<spirv::StructType>().getElementType(0);
+  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
+    dstType = arrayType.getElementType();
+  else
+    dstType = dstType.cast<spirv::RuntimeArrayType>().getElementType();
+
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
@@ -1136,13 +1140,17 @@
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
-  auto dstType = typeConverter.convertType(memrefType)
+  Type dstType = typeConverter.convertType(memrefType)
                      .cast<spirv::PointerType>()
-                     .getPointeeType()
-                     .cast<spirv::StructType>()
-                     .getElementType(0)
-                     .cast<spirv::ArrayType>()
-                     .getElementType();
+                     .getPointeeType();
+  // The pointee is a struct whose first member is the backing array: a
+  // spirv::ArrayType, or a spirv::RuntimeArrayType (e.g. unknown dimension).
+  dstType = dstType.cast<spirv::StructType>().getElementType(0);
+  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
+    dstType = arrayType.getElementType();
+  else
+    dstType = dstType.cast<spirv::RuntimeArrayType>().getElementType();
+
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -905,6 +905,19 @@
   return
 }
 
+// CHECK-LABEL: func @load_store_unknown_dim
+// CHECK-SAME: %[[SRC:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>,
+// CHECK-SAME: %[[DST:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>)
+func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
+  // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
+  // CHECK: spv.Load "StorageBuffer" %[[AC0]]
+  %0 = memref.load %source[%i] : memref<?xi32>
+  // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
+  // CHECK: spv.Store "StorageBuffer" %[[AC1]]
+  memref.store %0, %dest[%i]: memref<?xi32>
+  return
+}
+
 } // end module
 
 // -----