diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -506,9 +506,14 @@
             dstElemVecType.getElementType()) {
           int64_t count =
               dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
-          auto castType =
-              VectorType::get({count}, srcElemVecType.getElementType());
-          for (auto &c : components)
+
+          // Make sure not to create 1-element vectors, which are illegal in
+          // SPIR-V.
+          Type castType = srcElemVecType.getElementType();
+          if (count > 1)
+            castType = VectorType::get({count}, castType);
+
+          for (Value &c : components)
             c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
         }
       }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -508,6 +508,48 @@
 
 // -----
 
+// Make sure we do not create 1-element vectors when bitcasting components.
+spirv.module Logical GLSL450 {
+  spirv.GlobalVariable @var01_v2f16 bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+  spirv.GlobalVariable @var01_v2f32 bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+
+  spirv.func @aliased(%index: i32) -> vector<3xf32> "None" {
+    %c0 = spirv.Constant 0 : i32
+    %v0 = spirv.Constant dense<0.0> : vector<3xf32>
+    %addr0 = spirv.mlir.addressof @var01_v2f16 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+    %ac0 = spirv.AccessChain %addr0[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value0 = spirv.Load "StorageBuffer" %ac0 : vector<2xf16>
+
+    %addr1 = spirv.mlir.addressof @var01_v2f32 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+    %ac1 = spirv.AccessChain %addr1[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>, i32, i32
+    %value1 = spirv.Load "StorageBuffer" %ac1 : vector<2xf32>
+
+    %val0_as_f32 = spirv.Bitcast %value0 : vector<2xf16> to f32
+
+    %res = spirv.CompositeConstruct %val0_as_f32, %value1 : (f32, vector<2xf32>) -> vector<3xf32>
+
+    spirv.ReturnValue %res : vector<3xf32>
+  }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK: spirv.GlobalVariable @var01_v2f16 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK: spirv.func @aliased
+
+// CHECK: %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+// CHECK: %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+// CHECK: %[[LD2:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+
+// CHECK-DAG: %[[ELEM0:.+]] = spirv.Bitcast %[[LD0]] : vector<2xf16> to f32
+// CHECK-DAG: %[[ELEM1:.+]] = spirv.Bitcast %[[LD1]] : vector<2xf16> to f32
+// CHECK-DAG: %[[ELEM2:.+]] = spirv.Bitcast %[[LD2]] : vector<2xf16> to f32
+
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[ELEM0]], %{{.+}} : (f32, vector<2xf32>) -> vector<3xf32>
+// CHECK: spirv.ReturnValue %[[RES]] : vector<3xf32>
+
+// -----
+
 // Make sure we do not crash on function arguments.
 
 spirv.module Logical GLSL450 {