diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -366,7 +366,6 @@ } Location loc = acOp.getLoc(); - auto i32Type = rewriter.getI32Type(); if (srcElemType.isIntOrFloat() && dstElemType.isa()) { // The source indices are for a buffer with scalar element types. Rewrite @@ -376,16 +375,19 @@ int srcNumBytes = *srcElemType.getSizeInBytes(); int dstNumBytes = *dstElemType.getSizeInBytes(); assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0); - int ratio = dstNumBytes / srcNumBytes; - auto ratioValue = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(ratio)); auto indices = llvm::to_vector<4>(acOp.getIndices()); Value oldIndex = indices.back(); + Type indexType = oldIndex.getType(); + + int ratio = dstNumBytes / srcNumBytes; + auto ratioValue = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); + indices.back() = - rewriter.create(loc, i32Type, oldIndex, ratioValue); + rewriter.create(loc, indexType, oldIndex, ratioValue); indices.push_back( - rewriter.create(loc, i32Type, oldIndex, ratioValue)); + rewriter.create(loc, indexType, oldIndex, ratioValue)); rewriter.replaceOpWithNewOp( acOp, adaptor.getBasePtr(), indices); @@ -400,14 +402,17 @@ int srcNumBytes = *srcElemType.getSizeInBytes(); int dstNumBytes = *dstElemType.getSizeInBytes(); assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0); - int ratio = srcNumBytes / dstNumBytes; - auto ratioValue = rewriter.create( - loc, i32Type, rewriter.getI32IntegerAttr(ratio)); auto indices = llvm::to_vector<4>(acOp.getIndices()); Value oldIndex = indices.back(); + Type indexType = oldIndex.getType(); + + int ratio = srcNumBytes / dstNumBytes; + auto ratioValue = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); + indices.back() = - rewriter.create(loc, i32Type, oldIndex, ratioValue); + rewriter.create(loc, indexType, oldIndex, ratioValue); rewriter.replaceOpWithNewOp( acOp, adaptor.getBasePtr(), indices); diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir @@ -32,6 +32,33 @@ // ----- +spirv.module Logical GLSL450 { + spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr [0])>, StorageBuffer> + spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr, stride=16> [0])>, StorageBuffer> + + spirv.func @load_store_scalar_64bit(%index: i64) -> f32 "None" { + %c0 = spirv.Constant 0 : i64 + %addr = spirv.mlir.addressof @var01s : !spirv.ptr [0])>, StorageBuffer> + %ac = spirv.AccessChain %addr[%c0, %index] : !spirv.ptr [0])>, StorageBuffer>, i64, i64 + %value = spirv.Load "StorageBuffer" %ac : f32 + spirv.Store "StorageBuffer" %ac, %value : f32 + spirv.ReturnValue %value : f32 + } +} + +// CHECK-LABEL: spirv.module + +// CHECK-NOT: @var01s +// CHECK: spirv.GlobalVariable @var01v bind(0, 1) : !spirv.ptr, stride=16> [0])>, StorageBuffer> +// CHECK-NOT: @var01s + +// CHECK: spirv.func @load_store_scalar_64bit(%[[INDEX:.+]]: i64) +// CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i64 +// CHECK: spirv.SDiv %[[INDEX]], %[[C4]] : i64 +// CHECK: spirv.SMod %[[INDEX]], %[[C4]] : i64 + +// ----- + spirv.module Logical GLSL450 { spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr [0])>, StorageBuffer> spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr, stride=16> [0])>, StorageBuffer>