[mlir][spirv] Do not crash in UnifyAliasedResource on function arguments

ResourceAliasAnalysis::shouldUnify() can be handed a null Operation * when
the value being inspected has no defining op, e.g. a function argument.
Treat that as "nothing to unify" instead of dereferencing it, and add a
regression test that runs the pass over an access chain rooted at a
function argument.

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -220,6 +220,12 @@
 }
 
 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
+  // A null `op` means the queried value has no defining op (e.g. it is a
+  // function argument); report that there is nothing to unify rather than
+  // dereferencing it.
+  if (!op)
+    return false;
+
   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
     auto canonicalOp = getCanonicalResource(varOp);
     return canonicalOp && varOp != canonicalOp;
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -506,3 +506,21 @@
 // CHECK: %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
 // CHECK: spirv.ReturnValue %[[CC]]
 
+// -----
+
+// Make sure we do not crash on function arguments, which have no defining op,
+// and that access chains rooted at them are left untouched.
+
+spirv.module Logical GLSL450 {
+  spirv.func @main(%arg0: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>) "None" {
+    %cst0_i32 = spirv.Constant 0 : i32
+    %0 = spirv.AccessChain %arg0[%cst0_i32, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+    spirv.Return
+  }
+}
+
+// CHECK-LABEL: spirv.module
+// CHECK-LABEL: spirv.func @main
+// CHECK-SAME: (%{{.+}}: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>) "None"
+// CHECK: spirv.AccessChain %{{.+}}[
+// CHECK: spirv.Return