diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -69,6 +69,9 @@ if (!elementSize) { return llvm::None; } + if (memRefType.getRank() == 0) { + return elementSize; + } auto dims = memRefType.getShape(); if (llvm::is_contained(dims, ShapedType::kDynamicSize) || offset == MemRefType::getDynamicStrideOrOffset() || @@ -334,7 +337,9 @@ // Add a '0' at the start to index into the struct. linearizedIndices.push_back(builder.create( loc, indexType, IntegerAttr::get(indexType, 0))); - linearizedIndices.push_back(ptrLoc); + if (ptrLoc) { + linearizedIndices.push_back(ptrLoc); + } return builder.create(loc, basePtr, linearizedIndices); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -312,3 +312,25 @@ func @memref_type(%arg0: memref<3xi1>) { return } + +// CHECK-LABEL: @load_store_zero_rank_float +// CHECK: !spv.ptr [0]>, StorageBuffer> +func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { + // CHECK: spv.AccessChain + // CHECK: spv.Load + // CHECK: spv.Store + %0 = load %arg0[] : memref + store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: @load_store_zero_rank_int +// CHECK: !spv.ptr [0]>, StorageBuffer> +func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { + // CHECK: spv.AccessChain + // CHECK: spv.Load + // CHECK: spv.Store + %0 = load %arg0[] : memref + store %0, %arg1[] : memref + return +}