diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -485,11 +485,28 @@ // bitwidth element type. For spirv.bitcast, the lower-numbered components // of the vector map to lower-ordered bits of the larger bitwidth element // type. + Type vectorType = srcElemType; if (!srcElemType.isa()) vectorType = VectorType::get({ratio}, dstElemType); + + // If both the source and destination are vector types, we need to make + // sure the scalar type is the same for composite construction later. + if (auto srcElemVecType = srcElemType.dyn_cast()) + if (auto dstElemVecType = dstElemType.dyn_cast()) { + if (srcElemVecType.getElementType() != + dstElemVecType.getElementType()) { + int64_t count = + dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8); + auto castType = + VectorType::get({count}, srcElemVecType.getElementType()); + for (auto &c : components) + c = rewriter.create(loc, castType, c); + } + } Value vectorValue = rewriter.create( loc, vectorType, components); + if (!srcElemType.isa()) vectorValue = rewriter.create(loc, srcElemType, vectorValue); diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir @@ -448,3 +448,34 @@ // CHECK: spirv.GlobalVariable @var01_i16 bind(0, 1) {aliased} // CHECK: spirv.func @scalar_type_bitwidth_smaller_than_vector + +// ----- + +spirv.module Logical GLSL450 { + spirv.GlobalVariable @var00_v4f32 bind(0, 0) {aliased} : !spirv.ptr, stride=16> [0])>, StorageBuffer> + spirv.GlobalVariable @var00_v4f16 bind(0, 0) {aliased} : !spirv.ptr, stride=8> [0])>, StorageBuffer> + + spirv.func @vector_type_same_size_different_element_type(%i0: i32) -> vector<4xf32> "None" { + %c0 = spirv.Constant 0 : i32 + + %addr = spirv.mlir.addressof @var00_v4f32 : !spirv.ptr, stride=16> [0])>, StorageBuffer> + %ac = spirv.AccessChain %addr[%c0, %i0] : !spirv.ptr, stride=16> [0])>, StorageBuffer>, i32, i32 + %val = spirv.Load "StorageBuffer" %ac : vector<4xf32> + + spirv.ReturnValue %val : vector<4xf32> + } +} + +// CHECK-LABEL: spirv.module + +// CHECK: spirv.GlobalVariable @var00_v4f16 bind(0, 0) : !spirv.ptr, stride=8> [0])>, StorageBuffer> + +// CHECK: spirv.func @vector_type_same_size_different_element_type + +// CHECK: %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16> +// CHECK: %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16> +// CHECK: %[[BC0:.+]] = spirv.Bitcast %[[LD0]] : vector<4xf16> to vector<2xf32> +// CHECK: %[[BC1:.+]] = spirv.Bitcast %[[LD1]] : vector<4xf16> to vector<2xf32> +// CHECK: %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32> +// CHECK: spirv.ReturnValue %[[CC]] +