diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -69,6 +69,9 @@ if (!elementSize) { return llvm::None; } + if (memRefType.getRank() == 0) { + return elementSize; + } auto dims = memRefType.getShape(); if (llvm::is_contained(dims, ShapedType::kDynamicSize) || offset == MemRefType::getDynamicStrideOrOffset() || @@ -332,8 +335,12 @@ } SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. - linearizedIndices.push_back(builder.create( - loc, indexType, IntegerAttr::get(indexType, 0))); + auto zero = spirv::ConstantOp::getZero(indexType, loc, &builder); + linearizedIndices.push_back(zero); + // If it is a zero-rank memref type, extract the element directly. + if (!ptrLoc) { + ptrLoc = zero; + } linearizedIndices.push_back(ptrLoc); return builder.create(loc, basePtr, linearizedIndices); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -312,3 +312,41 @@ func @memref_type(%arg0: memref<3xi1>) { return } + +// CHECK-LABEL: @load_store_zero_rank_float +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { + // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 + %0 = load %arg0[] : memref + // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32 + store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: @load_store_zero_rank_int +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { + // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32 + %0 = load %arg0[] : memref + // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32 + store %0, %arg1[] : memref + return +}