diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -78,6 +78,7 @@
 static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
   // scalarNumBits: contains all resources' scalar types' bit counts.
   // vectorNumBits: only contains resources whose element types are vectors.
+  // vectorIndices: each vector's original index in `types`.
   SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
   scalarNumBits.reserve(types.size());
   vectorNumBits.reserve(types.size());
@@ -104,11 +105,6 @@
   }

   if (!vectorNumBits.empty()) {
-    // If there are vector types, require all element types to be the same for
-    // now to simplify the transformation.
-    if (!llvm::is_splat(scalarNumBits))
-      return llvm::None;
-
     // Choose the *vector* with the smallest bitwidth as the canonical resource,
     // so that we can still keep vectorized load/store and avoid partial updates
     // to large vectors.
@@ -116,10 +112,18 @@
     // Make sure that the canonical resource's bitwidth is divisible by others.
     // Without this, we cannot properly adjust the index later.
     if (llvm::any_of(vectorNumBits,
-                     [minVal](int64_t bits) { return bits % *minVal != 0; }))
+                     [&](int bits) { return bits % *minVal != 0; }))
+      return llvm::None;
+
+    // Require all scalar type bit counts to be a multiple of the chosen
+    // vector's primitive type to avoid reading/writing subcomponents.
+    int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
+    int baseNumBits = scalarNumBits[index];
+    if (llvm::any_of(scalarNumBits,
+                     [&](int bits) { return bits % baseNumBits != 0; }))
       return llvm::None;

-    return vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
+    return index;
   }

   // All element types are scalars. Then choose the smallest bitwidth as the
@@ -357,10 +361,10 @@
   // them into a buffer with vector element types. We need to scale the last
   // index for the vector as a whole, then add one level of index for inside
   // the vector.
-  int srcNumBits = *srcElemType.getSizeInBytes();
-  int dstNumBits = *dstElemType.getSizeInBytes();
-  assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
-  int ratio = dstNumBits / srcNumBits;
+  int srcNumBytes = *srcElemType.getSizeInBytes();
+  int dstNumBytes = *dstElemType.getSizeInBytes();
+  assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
+  int ratio = dstNumBytes / srcNumBytes;
   auto ratioValue = rewriter.create<spirv::ConstantOp>(
       loc, i32Type, rewriter.getI32IntegerAttr(ratio));

@@ -381,10 +385,10 @@
   // The source indices are for a buffer with larger bitwidth scalar/vector
   // element types. Rewrite them into a buffer with smaller bitwidth element
   // types. We only need to scale the last index.
-  int srcNumBits = *srcElemType.getSizeInBytes();
-  int dstNumBits = *dstElemType.getSizeInBytes();
-  assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
-  int ratio = srcNumBits / dstNumBits;
+  int srcNumBytes = *srcElemType.getSizeInBytes();
+  int dstNumBytes = *dstElemType.getSizeInBytes();
+  assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
+  int ratio = srcNumBytes / dstNumBytes;
   auto ratioValue = rewriter.create<spirv::ConstantOp>(
       loc, i32Type, rewriter.getI32IntegerAttr(ratio));

@@ -435,10 +439,10 @@
   // vector types of different component counts. For such cases, we load
   // multiple smaller bitwidth values and construct a larger bitwidth one.
-  int srcNumBits = *srcElemType.getSizeInBytes() * 8;
-  int dstNumBits = *dstElemType.getSizeInBytes() * 8;
-  assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
-  int ratio = srcNumBits / dstNumBits;
+  int srcNumBytes = *srcElemType.getSizeInBytes();
+  int dstNumBytes = *dstElemType.getSizeInBytes();
+  assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
+  int ratio = srcNumBytes / dstNumBytes;

   if (ratio > 4)
     return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -189,7 +189,7 @@
   spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
   spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>

-  spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
+  spv.func @different_primitive_type(%index: i32, %val0: i32) -> i32 "None" {
     %c0 = spv.Constant 0 : i32
     %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
     %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
@@ -205,7 +205,7 @@
 // CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
 // CHECK-NOT: @var01s

-// CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
+// CHECK: spv.func @different_primitive_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
 // CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01v
 // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
 // CHECK: %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
@@ -329,3 +329,122 @@
 // CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
 // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[DIV]], %[[MOD]]]
 // CHECK: %[[LD:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_v4f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+
+  spv.func @load_mixed_scalar_vector_primitive_types(%i0: i32) -> vector<4xf32> "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01_v4f32 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
+    %vec4val = spv.Load "StorageBuffer" %ac0 : vector<4xf32>
+
+    %addr1 = spv.mlir.addressof @var01_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %f32val = spv.Load "StorageBuffer" %ac1 : f32
+
+    %addr2 = spv.mlir.addressof @var01_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+    %ac2 = spv.AccessChain %addr2[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>, i32, i32
+    %i64val = spv.Load "StorageBuffer" %ac2 : i64
+    %i32val = spv.SConvert %i64val : i64 to i32
+    %castval = spv.Bitcast %i32val : i32 to f32
+
+    %val1 = spv.CompositeInsert %f32val, %vec4val[0 : i32] : f32 into vector<4xf32>
+    %val2 = spv.CompositeInsert %castval, %val1[1 : i32] : f32 into vector<4xf32>
+    spv.ReturnValue %val2 : vector<4xf32>
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01_f32
+// CHECK-NOT: @var01_i64
+// CHECK: spv.GlobalVariable @var01_v4f32 bind(0, 1) : !spv.ptr<{{.+}}>
+// CHECK-NOT: @var01_f32
+// CHECK-NOT: @var01_i64
+
+// CHECK: spv.func @load_mixed_scalar_vector_primitive_types(%[[IDX:.+]]: i32)
+
+// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
+// CHECK: %[[ADDR0:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK: %[[AC0:.+]] = spv.AccessChain %[[ADDR0]][%[[ZERO]], %[[IDX]]]
+// CHECK: spv.Load "StorageBuffer" %[[AC0]] : vector<4xf32>
+
+// CHECK: %[[ADDR1:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
+// CHECK: %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[FOUR]] : i32
+// CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[FOUR]] : i32
+// CHECK: %[[AC1:.+]] = spv.AccessChain %[[ADDR1]][%[[ZERO]], %[[DIV]], %[[MOD]]]
+// CHECK: spv.Load "StorageBuffer" %[[AC1]] : f32
+
+// CHECK: %[[ADDR2:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32
+// CHECK: %[[DIV0:.+]] = spv.SDiv %[[IDX]], %[[TWO]] : i32
+// CHECK: %[[MOD0:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
+// CHECK: %[[AC2:.+]] = spv.AccessChain %[[ADDR2]][%[[ZERO]], %[[DIV0]], %[[MOD0]]]
+// CHECK: %[[LD0:.+]] = spv.Load "StorageBuffer" %[[AC2]] : f32
+
+// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK: %[[MOD1:.+]] = spv.IAdd %[[MOD0]], %[[ONE]]
+// CHECK: %[[AC3:.+]] = spv.AccessChain %[[ADDR2]][%[[ZERO]], %[[DIV0]], %[[MOD1]]]
+// CHECK: %[[LD1:.+]] = spv.Load "StorageBuffer" %[[AC3]] : f32
+// CHECK: %[[CC:.+]] = spv.CompositeConstruct %[[LD0]], %[[LD1]]
+// CHECK: %[[BC:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+
+  spv.func @load_mixed_scalar_vector_primitive_types(%i0: i32) -> i64 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr = spv.mlir.addressof @var01_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>, i32, i32
+    %val = spv.Load "StorageBuffer" %ac : i64
+
+    spv.ReturnValue %val : i64
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.func @load_mixed_scalar_vector_primitive_types(%[[IDX:.+]]: i32)
+
+// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01_v2f2
+// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK: %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[ONE]] : i32
+// CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[ONE]] : i32
+// CHECK: spv.AccessChain %[[ADDR]][%{{.+}}, %[[DIV]], %[[MOD]]]
+// CHECK: spv.Load
+// CHECK: spv.Load
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_i16 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>
+
+  spv.func @scalar_type_bitwidth_smaller_than_vector(%i0: i32) -> i16 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr = spv.mlir.addressof @var01_i16 : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>, i32, i32
+    %val = spv.Load "StorageBuffer" %ac : i16
+
+    spv.ReturnValue %val : i16
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased}
+// CHECK: spv.GlobalVariable @var01_i16 bind(0, 1) {aliased}
+
+// CHECK: spv.func @scalar_type_bitwidth_smaller_than_vector
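
Note on the arithmetic exercised by the patch and tests above: the canonical resource is the aliased vector resource with the smallest element bitwidth, every other vector element bitwidth must divide evenly into it, and (new in this patch) every resource's scalar bitwidth must be a multiple of the canonical vector's primitive-type bitwidth; accesses to the other resources are then remapped with a div/mod on the last index. The snippet below is a minimal standalone sketch of that selection and remapping arithmetic under those assumptions, not the pass's actual code: the Resource struct and pickCanonical helper are illustrative names only, and plain byte counts stand in for MLIR's SPIR-V types.

// Standalone sketch only -- mirrors the post-patch selection logic of
// deduceCanonicalResource and the SDiv/SMod index remapping, using plain
// integers instead of MLIR SPIR-V types. Build with: c++ -std=c++17 sketch.cpp
#include <cassert>
#include <cstdio>
#include <optional>
#include <vector>

struct Resource {
  int elemNumBytes;   // byte size of the whole element (scalar or vector)
  int scalarNumBytes; // byte size of the scalar / vector primitive type
  bool isVector;
};

// Pick the vector resource with the smallest element size as the canonical
// one; reject the set if any element size or scalar size does not divide
// evenly, since then indices could not be adjusted without reading or
// writing subcomponents.
std::optional<int> pickCanonical(const std::vector<Resource> &resources) {
  std::optional<int> best;
  for (int i = 0, e = static_cast<int>(resources.size()); i < e; ++i) {
    if (!resources[i].isVector)
      continue;
    if (!best || resources[i].elemNumBytes < resources[*best].elemNumBytes)
      best = i;
  }
  if (!best)
    return std::nullopt; // all-scalar case: handled separately in the pass

  int canonElem = resources[*best].elemNumBytes;
  int canonScalar = resources[*best].scalarNumBytes;
  for (const Resource &r : resources) {
    if (r.isVector && r.elemNumBytes % canonElem != 0)
      return std::nullopt; // cannot evenly scale the last index
    if (r.scalarNumBytes % canonScalar != 0)
      return std::nullopt; // would need subcomponent accesses
  }
  return best;
}

int main() {
  // Mirrors the first new test: vector<4xf32> (16B), f32 (4B), i64 (8B).
  std::vector<Resource> resources = {
      {16, 4, true}, {4, 4, false}, {8, 8, false}};
  std::optional<int> canon = pickCanonical(resources);
  assert(canon && *canon == 0); // vector<4xf32> is canonical

  // Remapping an f32 access at index 9 onto the vector<4xf32> buffer:
  // ratio = 16B / 4B = 4, so it becomes element 9/4 = 2, component 9%4 = 1,
  // matching the spv.SDiv/spv.SMod by 4 in the CHECK lines above.
  int index = 9;
  int ratio = resources[*canon].elemNumBytes / resources[1].elemNumBytes;
  std::printf("canonical=%d element=%d component=%d\n", *canon, index / ratio,
              index % ratio);
  return 0;
}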