diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -173,8 +173,9 @@ auto idx = builder.create(loc, targetType, attr); auto lastDim = op->getOperand(op.getNumOperands() - 1); auto indices = llvm::to_vector<4>(op.indices()); - // There are two elements if this is a 1-D tensor. - assert(indices.size() == 2); + // There are at most two elements (one implicit zero for indexing into + // wrapping struct in interface storage classes) if this is a 1-D tensor. + assert(indices.size() <= 2); indices.back() = builder.create(loc, lastDim, idx); Type t = typeConverter.convertType(op.component_ptr().getType()); return builder.create(loc, t, op.base_ptr(), indices); @@ -989,13 +990,12 @@ loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - auto dstType = typeConverter.convertType(memrefType) + Type dstType = typeConverter.convertType(memrefType) .cast() - .getPointeeType() - .cast() - .getElementType(0) - .cast() - .getElementType(); + .getPointeeType(); + if (dstType.isa()) + dstType = dstType.cast().getElementType(0); + dstType = dstType.cast().getElementType(); int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); @@ -1010,7 +1010,10 @@ // Assume that getElementPtr() works linearizely. If it's a scalar, the method // still returns a linearized accessing. If the accessing is not linearized, // there will be offset issues. - assert(accessChainOp.indices().size() == 2); + + // There are at most two indices (one implicit zero for indexing into + // wrapping struct in interface storage classes). + assert(accessChainOp.indices().size() <= 2); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create( @@ -1114,13 +1117,12 @@ spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - auto dstType = typeConverter.convertType(memrefType) + Type dstType = typeConverter.convertType(memrefType) .cast() - .getPointeeType() - .cast() - .getElementType(0) - .cast() - .getElementType(); + .getPointeeType(); + if (dstType.isa()) + dstType = dstType.cast().getElementType(0); + dstType = dstType.cast().getElementType(); int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); @@ -1141,7 +1143,10 @@ // 6) store 32-bit value back // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step // 4 to step 6 are done by AtomicOr as another atomic step. - assert(accessChainOp.indices().size() == 2); + + // There are at most two indices (one implicit zero for indexing into + // wrapping struct in interface storage classes). + assert(accessChainOp.indices().size() <= 2); Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -82,6 +82,20 @@ return success(); } +/// Returns true if the given `storageClass` needs explicit layout when used in +/// Shader environments. +static bool needsExplicitLayout(spirv::StorageClass storageClass) { + switch (storageClass) { + case spirv::StorageClass::PhysicalStorageBuffer: + case spirv::StorageClass::PushConstant: + case spirv::StorageClass::StorageBuffer: + case spirv::StorageClass::Uniform: + return true; + default: + return false; + } +} + //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// @@ -401,12 +415,12 @@ auto arrayType = spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize); - // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with - // workgroup storage class do not need the struct to be laid out explicitly. - auto structType = *storageClass == spirv::StorageClass::Workgroup - ? spirv::StructType::get(arrayType) - : spirv::StructType::get(arrayType, 0); - return spirv::PointerType::get(structType, *storageClass); + Type resultType = arrayType; + if (needsExplicitLayout(*storageClass)) { + // Wrap in a struct to satisfy Vulkan interface requirements. + resultType = spirv::StructType::get(resultType, 0); + } + return spirv::PointerType::get(resultType, *storageClass); } SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) @@ -650,8 +664,15 @@ SmallVector linearizedIndices; auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); - // Add a '0' at the start to index into the struct. - linearizedIndices.push_back(zero); + auto storageClass = SPIRVTypeConverter::getStorageClassForMemorySpace( + baseType.getMemorySpaceAsInt()); + if (!storageClass) + return {}; + + if (needsExplicitLayout(*storageClass)) { + // Add a '0' at the start to index into the struct. + linearizedIndices.push_back(zero); + } if (baseType.getRank() == 0) { linearizedIndices.push_back(zero); diff --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir @@ -9,7 +9,7 @@ // CHECK: spv.func // CHECK-SAME: {{%.*}}: f32 // CHECK-NOT: spv.interface_var_abi - // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, CrossWorkgroup> + // CHECK-SAME: {{%.*}}: !spv.ptr, CrossWorkgroup> // CHECK-NOT: spv.interface_var_abi // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -17,7 +17,7 @@ return } } -// CHECK: spv.GlobalVariable @[[VAR:.+]] : !spv.ptr)>, Workgroup> +// CHECK: spv.GlobalVariable @[[VAR:.+]] : !spv.ptr, Workgroup> // CHECK: func @alloc_dealloc_workgroup_mem // CHECK-NOT: alloc // CHECK: %[[PTR:.+]] = spv.mlir.addressof @[[VAR]] @@ -45,14 +45,14 @@ } // CHECK: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr)>, Workgroup> +// CHECK-SAME: !spv.ptr, Workgroup> // CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem // CHECK: %[[VAR:.+]] = spv.mlir.addressof @__workgroup_mem__0 // CHECK: %[[LOC:.+]] = spv.SDiv -// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]] +// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%[[LOC]]] // CHECK: %{{.+}} = spv.Load "Workgroup" %[[PTR]] : i32 // CHECK: %[[LOC:.+]] = spv.SDiv -// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]] +// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%[[LOC]]] // CHECK: %{{.+}} = spv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr // CHECK: %{{.+}} = spv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr @@ -72,9 +72,9 @@ } // CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr)>, Workgroup> +// CHECK-SAME: !spv.ptr, Workgroup> // CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr)>, Workgroup> +// CHECK-SAME: !spv.ptr, Workgroup> // CHECK: spv.func @two_allocs() // CHECK: spv.Return @@ -93,9 +93,9 @@ } // CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr, stride=8>)>, Workgroup> +// CHECK-SAME: !spv.ptr, stride=8>, Workgroup> // CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} -// CHECK-SAME: !spv.ptr, stride=16>)>, Workgroup> +// CHECK-SAME: !spv.ptr, stride=16>, Workgroup> // CHECK: spv.func @two_allocs_vector() // CHECK: spv.Return diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -299,12 +299,12 @@ func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0])>, Input> -func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } +// CHECK-SAME: !spv.ptr, Input> +func @memref_16bit_Input(%arg3: memref<4xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0])>, Output> -func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } +// CHECK-SAME: !spv.ptr, Output> +func @memref_16bit_Output(%arg4: memref<4xf16, 10>) { return } } // end module @@ -394,12 +394,12 @@ } { // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0])>, Input> -func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } +// CHECK-SAME: !spv.ptr, Input> +func @memref_16bit_Input(%arg3: memref<4xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0])>, Output> -func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } +// CHECK-SAME: !spv.ptr, Output> +func @memref_16bit_Output(%arg4: memref<4xi16, 10>) { return } } // end module