diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -69,6 +69,9 @@ if (!elementSize) { return llvm::None; } + if (memRefType.getRank() == 0) { + return elementSize; + } auto dims = memRefType.getShape(); if (llvm::is_contained(dims, ShapedType::kDynamicSize) || offset == MemRefType::getDynamicStrideOrOffset() || @@ -325,8 +328,12 @@ } SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. - linearizedIndices.push_back(builder.create( - loc, indexType, IntegerAttr::get(indexType, 0))); + auto zero = spirv::ConstantOp::getZero(indexType, loc, &builder); + linearizedIndices.push_back(zero); + // If it is a zero-rank memref type, extract the element directly. + if (!ptrLoc) { + ptrLoc = zero; + } linearizedIndices.push_back(ptrLoc); return builder.create(loc, basePtr, linearizedIndices); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -312,3 +312,41 @@ func @memref_type(%arg0: memref<3xi1>) { return } + +// CHECK-LABEL: @load_store_zero_rank_float +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { + // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 + %0 = load %arg0[] : memref + // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32 + store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: @load_store_zero_rank_int +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { + // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32 + %0 = load %arg0[] : memref + // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32 + store %0, %arg1[] : memref + return +} diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -23,3 +23,37 @@ spv.Return } } + +// ----- + +spv.module "Logical" "GLSL450" { + spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 + %0 = spv.constant 0 : i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.Load "StorageBuffer" %1 : f32 + + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 + %3 = spv.constant 0 : i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + spv.Store "StorageBuffer" %4, %2 : f32 + spv.Return + } + + spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 + %0 = spv.constant 0 : i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.Load "StorageBuffer" %1 : i32 + + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 + %3 = spv.constant 0 : i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + spv.Store "StorageBuffer" %4, %2 : i32 + spv.Return + } +}