diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -385,7 +385,7 @@
     OptionalAttr<FlatSymbolRefAttr>:$initializer,
     OptionalAttr<I32Attr>:$location,
     OptionalAttr<I32Attr>:$binding,
-    OptionalAttr<I32Attr>:$descriptorSet,
+    OptionalAttr<I32Attr>:$descriptor_set,
     OptionalAttr<StrAttr>:$builtin
   );
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
@@ -55,6 +55,11 @@
 /// spv.CompositeInsert into spv.CompositeConstruct.
 std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
 
+/// Creates an operation pass that unifies access of multiple aliased resources
+/// into access of one single resource.
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+createUnifyAliasedResourcePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -28,6 +28,13 @@
   let constructor = "mlir::spirv::createRewriteInsertsPass()";
 }
 
+def SPIRVUnifyAliasedResourcePass
+    : Pass<"spirv-unify-aliased-resource", "spirv::ModuleOp"> {
+  let summary = "Unify access of multiple aliased resources into access of one "
+                "single resource";
+  let constructor = "mlir::spirv::createUnifyAliasedResourcePass()";
+}
+
 def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
   let summary = "Deduce and attach minimal (version, capabilities, extensions) "
                 "requirements to spv.module ops";
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
   SPIRVConversion.cpp
+  UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 )
 
@@ -21,6 +22,7 @@
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
+  UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -0,0 +1,452 @@
+//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that unifies access of multiple aliased resources
+// into access of one single resource.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/AnalysisManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+#include <iterator>
+
+#define DEBUG_TYPE "spirv-unify-aliased-resource"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
+using AliasedResourceMap =
+    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
+
+/// Collects all aliased resources in the given SPIR-V `moduleOp`, grouped by
+/// their (descriptor set, binding) pair.
+///
+/// Only spv.GlobalVariable ops carrying the `aliased` unit attribute are
+/// considered; variables missing either the descriptor set or the binding
+/// number are skipped.
+static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
+  AliasedResourceMap aliasedResources;
+  moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
+    if (varOp->getAttrOfType<UnitAttr>("aliased")) {
+      Optional<uint32_t> set = varOp.descriptor_set();
+      Optional<uint32_t> binding = varOp.binding();
+      if (set && binding)
+        aliasedResources[{*set, *binding}].push_back(varOp);
+    }
+  });
+  return aliasedResources;
+}
+
+/// Returns the element type if the given `type` is a runtime array resource:
+/// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
+static Type getRuntimeArrayElementType(Type type) {
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType)
+    return {};
+
+  auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+  if (!structType || structType.getNumElements() != 1)
+    return {};
+
+  auto rtArrayType =
+      structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
+  if (!rtArrayType)
+    return {};
+
+  return rtArrayType.getElementType();
+}
+
+/// Returns true if all `types`, which can either be scalar or vector types,
+/// have the same bitwidth base scalar type.
+static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
+  SmallVector<int64_t> bitwidths;
+  bitwidths.reserve(types.size());
+  for (spirv::SPIRVType type : types) {
+    assert(type.isScalarOrVector());
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      bitwidths.push_back(
+          vectorType.getElementType().getIntOrFloatBitWidth());
+    else
+      bitwidths.push_back(type.getIntOrFloatBitWidth());
+  }
+  return llvm::is_splat(bitwidths);
+}
+
+//===----------------------------------------------------------------------===//
+// Analysis
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A class for analyzing aliased resources.
+///
+/// Resources are expected to be spv.GlobalVariable ops that have a descriptor
+/// set and binding number. Such resources are of the type
+/// `!spv.ptr<!spv.struct<...>>` per Vulkan requirements.
+///
+/// Right now, we only support the case that there is a single runtime array
+/// inside the struct.
+class ResourceAliasAnalysis {
+public:
+  explicit ResourceAliasAnalysis(Operation *);
+
+  /// Returns true if the given `op` can be rewritten to use a canonical
+  /// resource. Returns false for a null `op`.
+  bool shouldUnify(Operation *op) const;
+
+  /// Returns all descriptors and their corresponding aliased resources.
+  const AliasedResourceMap &getResourceMap() const { return resourceMap; }
+
+  /// Returns the canonical resource for the given descriptor/variable, or a
+  /// null op if the descriptor's resources were not deemed unifiable.
+  spirv::GlobalVariableOp
+  getCanonicalResource(const Descriptor &descriptor) const;
+  spirv::GlobalVariableOp
+  getCanonicalResource(spirv::GlobalVariableOp varOp) const;
+
+  /// Returns the element type for the given variable, or a null type if the
+  /// variable was not recorded as part of a unifiable set.
+  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
+
+private:
+  /// Given the descriptor and aliased resources bound to it, analyze whether we
+  /// can unify them and record if so.
+  void recordIfUnifiable(const Descriptor &descriptor,
+                         ArrayRef<spirv::GlobalVariableOp> resources);
+
+  /// Mapping from a descriptor to all aliased resources bound to it.
+  AliasedResourceMap resourceMap;
+
+  /// Mapping from a descriptor to the chosen canonical resource.
+  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
+
+  /// Mapping from an aliased resource to its descriptor.
+  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
+
+  /// Mapping from an aliased resource to its element (scalar/vector) type.
+  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
+};
+} // namespace
+
+ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
+  // Collect all aliased resources first and put them into different sets
+  // according to the descriptor.
+  AliasedResourceMap aliasedResources =
+      collectAliasedResources(cast<spirv::ModuleOp>(root));
+
+  // For each resource set, analyze whether we can unify; if so, try to identify
+  // a canonical resource, whose element type has the largest bitwidth.
+  for (const auto &descriptorResource : aliasedResources)
+    recordIfUnifiable(descriptorResource.first, descriptorResource.second);
+}
+
+bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
+  // Pointers that are block arguments (e.g. pointer-typed spv.func
+  // parameters) have no defining op, and a symbol lookup may fail. We cannot
+  // trace such values back to a resource here, so leave them untouched rather
+  // than handing a null pointer to dyn_cast (which asserts on null).
+  if (!op)
+    return false;
+
+  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+    auto canonicalOp = getCanonicalResource(varOp);
+    return canonicalOp && varOp != canonicalOp;
+  }
+  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
+    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
+    auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
+    return shouldUnify(varOp);
+  }
+
+  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
+    return shouldUnify(acOp.base_ptr().getDefiningOp());
+  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
+    return shouldUnify(loadOp.ptr().getDefiningOp());
+  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
+    return shouldUnify(storeOp.ptr().getDefiningOp());
+
+  return false;
+}
+
+spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
+    const Descriptor &descriptor) const {
+  auto varIt = canonicalResourceMap.find(descriptor);
+  if (varIt == canonicalResourceMap.end())
+    return {};
+  return varIt->second;
+}
+
+spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
+    spirv::GlobalVariableOp varOp) const {
+  auto descriptorIt = descriptorMap.find(varOp);
+  if (descriptorIt == descriptorMap.end())
+    return {};
+  return getCanonicalResource(descriptorIt->second);
+}
+
+spirv::SPIRVType
+ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
+  auto it = elementTypeMap.find(varOp);
+  if (it == elementTypeMap.end())
+    return {};
+  return it->second;
+}
+
+void ResourceAliasAnalysis::recordIfUnifiable(
+    const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
+  // Collect the element types and byte counts for all resources in the
+  // current set.
+  SmallVector<spirv::SPIRVType> elementTypes;
+  SmallVector<int64_t> numBytes;
+
+  for (spirv::GlobalVariableOp resource : resources) {
+    Type elementType = getRuntimeArrayElementType(resource.type());
+    if (!elementType)
+      return; // Unexpected resource variable type.
+
+    auto type = elementType.cast<spirv::SPIRVType>();
+    if (!type.isScalarOrVector())
+      return; // Unexpected resource element type.
+
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      if (vectorType.getNumElements() % 2 != 0)
+        return; // Odd-sized vector has special layout requirements.
+
+    Optional<int64_t> count = type.getSizeInBytes();
+    if (!count)
+      return;
+
+    elementTypes.push_back(type);
+    numBytes.push_back(*count);
+  }
+
+  // Make sure base scalar types have the same bitwidth, so that we don't need
+  // to handle extracting components for now.
+  if (!hasSameBitwidthScalarType(elementTypes))
+    return;
+
+  // Make sure that the canonical resource's bitwidth is divisible by others.
+  // Without this, we cannot properly adjust the index later.
+  auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
+  if (llvm::any_of(numBytes, [maxCount](int64_t count) {
+        return *maxCount % count != 0;
+      }))
+    return;
+
+  // std::max_element returns the first maximum, so ties pick the earliest
+  // declared resource as canonical.
+  spirv::GlobalVariableOp canonicalResource =
+      resources[std::distance(numBytes.begin(), maxCount)];
+
+  // Update internal data structures for later use.
+  resourceMap[descriptor].assign(resources.begin(), resources.end());
+  canonicalResourceMap[descriptor] = canonicalResource;
+  for (const auto &resource : llvm::enumerate(resources)) {
+    descriptorMap[resource.value()] = descriptor;
+    elementTypeMap[resource.value()] = elementTypes[resource.index()];
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Base class for patterns that redirect accesses of a non-canonical aliased
+/// resource to the canonical one chosen by `ResourceAliasAnalysis`.
+template <typename OpTy>
+class ConvertAliasResource : public OpConversionPattern<OpTy> {
+public:
+  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
+                       MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
+
+protected:
+  const ResourceAliasAnalysis &analysis;
+};
+
+struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Just remove the aliased resource. Users will be rewritten to use the
+    // canonical one.
+    rewriter.eraseOp(varOp);
+    return success();
+  }
+};
+
+struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Rewrite the AddressOf op to get the address of the canonical resource.
+    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
+    auto srcVarOp = cast<spirv::GlobalVariableOp>(
+        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
+    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
+    return success();
+  }
+};
+
+struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
+    if (!addressOp)
+      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
+
+    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
+    auto srcVarOp = cast<spirv::GlobalVariableOp>(
+        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
+
+    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
+    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
+
+    if ((srcElemType == dstElemType) ||
+        (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
+      // We have the same bitwidth for source and destination element types.
+      // The indices stay the same.
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), adaptor.indices());
+      return success();
+    }
+
+    Location loc = acOp.getLoc();
+
+    if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+      // The source indices are for a buffer with scalar element types. Rewrite
+      // them into a buffer with vector element types. We need to scale the last
+      // index for the vector as a whole, then add one level of index for inside
+      // the vector.
+      int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
+
+      auto indices = llvm::to_vector<4>(adaptor.indices());
+      Value oldIndex = indices.back();
+      // Emit the index arithmetic in the original index's integer type instead
+      // of assuming i32, so the new ops stay type-consistent with it.
+      Type indexType = oldIndex.getType();
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
+      indices.back() =
+          rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
+      indices.push_back(
+          rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
+
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), indices);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
+  }
+};
+
+struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcElemType =
+        loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    auto dstElemType =
+        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
+      return rewriter.notifyMatchFailure(loadOp, "not scalar type");
+
+    Location loc = loadOp.getLoc();
+    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
+    // Carry over the original attributes (e.g. memory access/alignment), just
+    // like ConvertStore does for stores.
+    newLoadOp->setAttrs(loadOp->getAttrs());
+    if (srcElemType == dstElemType) {
+      rewriter.replaceOp(loadOp, newLoadOp->getResults());
+    } else {
+      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
+                                                      newLoadOp.value());
+      rewriter.replaceOp(loadOp, castOp->getResults());
+    }
+
+    return success();
+  }
+};
+
+struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcElemType =
+        storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    auto dstElemType =
+        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
+      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
+
+    Location loc = storeOp.getLoc();
+    Value value = adaptor.value();
+    if (srcElemType != dstElemType)
+      value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
+                                                storeOp->getAttrs());
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+class UnifyAliasedResourcePass final
+    : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+void UnifyAliasedResourcePass::runOnOperation() {
+  spirv::ModuleOp moduleOp = getOperation();
+  MLIRContext *context = &getContext();
+
+  // Analyze aliased resources first.
+  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
+
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
+                               spirv::AccessChainOp, spirv::LoadOp,
+                               spirv::StoreOp>(
+      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
+  target.addLegalDialect<spirv::SPIRVDialect>();
+
+  // Run patterns to rewrite usages of non-canonical resources.
+  RewritePatternSet patterns(context);
+  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
+               ConvertLoad, ConvertStore>(analysis, context);
+  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+    return signalPassFailure();
+
+  // Drop aliased attribute if we only have one single bound resource for a
+  // descriptor. We need to re-collect the map here because the conversion
+  // above is best effort; certain sets may not be converted.
+  AliasedResourceMap resourceMap = collectAliasedResources(moduleOp);
+  for (const auto &dr : resourceMap) {
+    const auto &resources = dr.second;
+    if (resources.size() == 1)
+      resources.front()->removeAttr("aliased");
+  }
+}
+
+std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
+spirv::createUnifyAliasedResourcePass() {
+  return std::make_unique<UnifyAliasedResourcePass>();
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -0,0 +1,234 @@
+// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @load_store_scalar(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value = spv.Load "StorageBuffer" %ac : f32
+    spv.Store "StorageBuffer" %ac, %value : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK:     spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK:     spv.func @load_store_scalar(%[[INDEX:.+]]: i32)
+// CHECK-DAG:   %[[C0:.+]] = spv.Constant 0 : i32
+// CHECK-DAG:   %[[C4:.+]] = spv.Constant 4 : i32
+// CHECK-DAG:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
+// CHECK:       %[[DIV:.+]] = spv.SDiv %[[INDEX]], %[[C4]] : i32
+// CHECK:       %[[MOD:.+]] = spv.SMod %[[INDEX]], %[[C4]] : i32
+// CHECK:       %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[C0]], %[[DIV]], %[[MOD]]]
+// CHECK:       spv.Load "StorageBuffer" %[[AC]]
+// CHECK:       spv.Store "StorageBuffer" %[[AC]]
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @multiple_uses(%i0: i32, %i1: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : f32
+    %ac1 = spv.AccessChain %addr[%c0, %i1] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val1 = spv.Load "StorageBuffer" %ac1 : f32
+    %value = spv.FAdd %val0, %val1 : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK:     spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spv.func @multiple_uses
+// CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
+// CHECK:   spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK:   spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<3xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @vector3(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value = spv.Load "StorageBuffer" %ac : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01s bind(0, 1) {aliased}
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) {aliased}
+// CHECK: spv.func @vector3
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(1, 0) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @not_aliased(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value = spv.Load "StorageBuffer" %ac : f32
+    spv.Store "StorageBuffer" %ac, %value : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01s bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+// CHECK: spv.GlobalVariable @var01v bind(1, 0) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK: spv.func @not_aliased
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01s_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @multiple_aliases(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : f32
+
+    %addr1 = spv.mlir.addressof @var01s_1 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val1 = spv.Load "StorageBuffer" %ac1 : f32
+
+    %addr2 = spv.mlir.addressof @var01v_1 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+    %ac2 = spv.AccessChain %addr2[%c0, %index, %c0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32, i32
+    %val2 = spv.Load "StorageBuffer" %ac2 : f32
+
+    %add0 = spv.FAdd %val0, %val1 : f32
+    %add1 = spv.FAdd %add0, %val2 : f32
+    spv.ReturnValue %add1 : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK:     spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01v_1
+
+// CHECK: spv.func @multiple_aliases
+// CHECK:   %[[ADDR0:.+]] = spv.mlir.addressof @var01v :
+// CHECK:   spv.AccessChain %[[ADDR0]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK:   %[[ADDR1:.+]] = spv.mlir.addressof @var01v :
+// CHECK:   spv.AccessChain %[[ADDR1]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK:   %[[ADDR2:.+]] = spv.mlir.addressof @var01v :
+// CHECK:   spv.AccessChain %[[ADDR2]][%{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s_i32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+  spv.func @different_scalar_type(%index: i32, %val1: f32) -> i32 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01s_i32 : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : i32
+
+    %addr1 = spv.mlir.addressof @var01s_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    spv.Store "StorageBuffer" %ac1, %val1 : f32
+
+    spv.ReturnValue %val0 : i32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s_f32
+// CHECK:     spv.GlobalVariable @var01s_i32 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s_f32
+
+// CHECK: spv.func @different_scalar_type(%[[INDEX:.+]]: i32, %[[VAL1:.+]]: f32)
+
+// CHECK:   %[[IADDR:.+]] = spv.mlir.addressof @var01s_i32
+// CHECK:   %[[IAC:.+]] = spv.AccessChain %[[IADDR]][%{{.+}}, %[[INDEX]]]
+// CHECK:   spv.Load "StorageBuffer" %[[IAC]] : i32
+
+// CHECK:   %[[FADDR:.+]] = spv.mlir.addressof @var01s_i32
+// CHECK:   %[[FAC:.+]] = spv.AccessChain %[[FADDR]][%cst0_i32, %[[INDEX]]]
+// CHECK:   %[[CAST:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
+// CHECK:   spv.Store "StorageBuffer" %[[FAC]], %[[CAST]] : i32
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val1 = spv.Load "StorageBuffer" %ac : i32
+    spv.Store "StorageBuffer" %ac, %val0 : i32
+    spv.ReturnValue %val1 : i32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK:     spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
+// CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
+// CHECK:   %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK:   %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
+// CHECK:   %[[CAST1:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
+// CHECK:   %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32
+// CHECK:   spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32
+// CHECK:   spv.ReturnValue %[[CAST1]] : i32
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @load_from_arg(%ptr: !spv.ptr<f32, Function>) -> f32 "None" {
+    %value = spv.Load "Function" %ptr : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK:     spv.GlobalVariable @var01v bind(0, 1) :
+// CHECK:     spv.func @load_from_arg
+// CHECK:       spv.Load "Function" %{{.+}} : f32