diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -335,8 +335,10 @@ .getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { - // For OpenCL Kernel, pointer will be directly pointing to the element. - dstType = pointeeType; + if (auto arrayType = pointeeType.dyn_cast()) + dstType = arrayType.getElementType(); + else + dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = @@ -464,8 +466,10 @@ .getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { - // For OpenCL Kernel, pointer will be directly pointing to the element. - dstType = pointeeType; + if (auto arrayType = pointeeType.dyn_cast()) + dstType = arrayType.getElementType(); + else + dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -338,15 +338,16 @@ return nullptr; } - // For OpenCL Kernel we can just emit a pointer pointing to the element. - if (targetEnv.allows(spirv::Capability::Kernel)) - return spirv::PointerType::get(arrayElemType, storageClass); - // For Vulkan we need extra wrapping struct and array to satisfy interface - // needs. if (!type.hasStaticShape()) { + // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing + // to the element. + if (targetEnv.allows(spirv::Capability::Kernel)) + return spirv::PointerType::get(arrayElemType, storageClass); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); + // For Vulkan we need extra wrapping struct and array to satisfy interface + // needs. return wrapInStructAndGetPointer(arrayType, storageClass); } @@ -354,7 +355,8 @@ auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); - + if (targetEnv.allows(spirv::Capability::Kernel)) + return spirv::PointerType::get(arrayType, storageClass); return wrapInStructAndGetPointer(arrayType, storageClass); } @@ -403,15 +405,16 @@ return nullptr; } - // For OpenCL Kernel we can just emit a pointer pointing to the element. - if (targetEnv.allows(spirv::Capability::Kernel)) - return spirv::PointerType::get(arrayElemType, storageClass); - // For Vulkan we need extra wrapping struct and array to satisfy interface - // needs. if (!type.hasStaticShape()) { + // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing + // to the element. + if (targetEnv.allows(spirv::Capability::Kernel)) + return spirv::PointerType::get(arrayElemType, storageClass); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); + // For Vulkan we need extra wrapping struct and array to satisfy interface + // needs. return wrapInStructAndGetPointer(arrayType, storageClass); } @@ -425,7 +428,8 @@ auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); - + if (targetEnv.allows(spirv::Capability::Kernel)) + return spirv::PointerType::get(arrayType, storageClass); return wrapInStructAndGetPointer(arrayType, storageClass); } @@ -776,15 +780,20 @@ auto indexType = typeConverter.getIndexType(); SmallVector linearizedIndices; - auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); - Value linearIndex; if (baseType.getRank() == 0) { - linearIndex = zero; + linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder); } else { linearIndex = linearizeIndex(indices, strides, offset, indexType, loc, builder); } + Type pointeeType = + basePtr.getType().cast().getPointeeType(); + if (pointeeType.isa()) { + linearizedIndices.push_back(linearIndex); + return builder.create(loc, basePtr, + linearizedIndices); + } return builder.create(loc, basePtr, linearIndex, linearizedIndices); } diff --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir @@ -9,7 +9,7 @@ // CHECK: spirv.func // CHECK-SAME: {{%.*}}: f32 // CHECK-NOT: spirv.interface_var_abi - // CHECK-SAME: {{%.*}}: !spirv.ptr + // CHECK-SAME: {{%.*}}: !spirv.ptr, CrossWorkgroup> // CHECK-NOT: spirv.interface_var_abi // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi : vector<3xi32>> gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class>) kernel diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir @@ -155,3 +155,27 @@ return } } + +// ----- +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> + } +{ + func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { + %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class> + %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class> + memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class> + memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class> + return + } +} +// CHECK: spirv.GlobalVariable @[[VAR:.+]] : !spirv.ptr, Workgroup> +// CHECK: func @alloc_dealloc_workgroup_mem +// CHECK-NOT: memref.alloc +// CHECK: %[[PTR:.+]] = spirv.mlir.addressof @[[VAR]] +// CHECK: %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]] +// CHECK: %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32 +// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]] +// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32 +// CHECK-NOT: memref.dealloc diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -121,16 +121,16 @@ // CHECK-LABEL: @load_store_zero_rank_float func.func @load_store_zero_rank_float(%arg0: memref>, %arg1: memref>) { - // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr - // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr + // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr, CrossWorkgroup> + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr, CrossWorkgroup> // CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32 - // CHECK: spirv.PtrAccessChain [[ARG0]][ + // CHECK: spirv.AccessChain [[ARG0]][ // CHECK-SAME: [[ZERO1]] // CHECK-SAME: ] : // CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32 %0 = memref.load %arg0[] : memref> // CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32 - // CHECK: spirv.PtrAccessChain [[ARG1]][ + // CHECK: spirv.AccessChain [[ARG1]][ // CHECK-SAME: [[ZERO2]] // CHECK-SAME: ] : // CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32 @@ -140,16 +140,16 @@ // CHECK-LABEL: @load_store_zero_rank_int func.func @load_store_zero_rank_int(%arg0: memref>, %arg1: memref>) { - // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr - // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr + // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr, CrossWorkgroup> + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spirv.ptr, CrossWorkgroup> // CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32 - // CHECK: spirv.PtrAccessChain [[ARG0]][ + // CHECK: spirv.AccessChain [[ARG0]][ // CHECK-SAME: [[ZERO1]] // CHECK-SAME: ] : // CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32 %0 = memref.load %arg0[] : memref> // CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32 - // CHECK: spirv.PtrAccessChain [[ARG1]][ + // CHECK: spirv.AccessChain [[ARG1]][ // CHECK-SAME: [[ZERO2]] // CHECK-SAME: ] : // CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32 @@ -173,14 +173,13 @@ // CHECK-LABEL: func @load_i1 // CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { - // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr, CrossWorkgroup> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32 - // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32 - // CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]] + // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]] // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8 // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8 @@ -194,14 +193,13 @@ // CHECK-SAME: %[[IDX:.+]]: index func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class>, %i: index) { %true = arith.constant true - // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr + // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr, CrossWorkgroup> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32 - // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32 - // CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[DST_CAST]][%[[ADD]]] + // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]] // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8