diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include +#include #define DEBUG_TYPE "spirv-unify-aliased-resource" @@ -72,20 +73,65 @@ return rtArrayType.getElementType(); } -/// Returns true if all `types`, which can either be scalar or vector types, -/// have the same bitwidth base scalar type. -static bool hasSameBitwidthScalarType(ArrayRef types) { - SmallVector scalarTypes; - scalarTypes.reserve(types.size()); +/// Given a list of resource element `types`, returns the index of the canonical +/// resource that all resources should be unified into. Returns llvm::None if +/// unable to unify. +static Optional deduceCanonicalResource(ArrayRef types) { + SmallVector scalarNumBits, totalNumBits; + scalarNumBits.reserve(types.size()); + totalNumBits.reserve(types.size()); + bool hasVector = false; + for (spirv::SPIRVType type : types) { assert(type.isScalarOrVector()); - if (auto vectorType = type.dyn_cast()) - scalarTypes.push_back( + if (auto vectorType = type.dyn_cast()) { + if (vectorType.getNumElements() % 2 != 0) + return llvm::None; // Odd-sized vector has special layout requirements. 
+ + Optional numBytes = type.getSizeInBytes(); + if (!numBytes) + return llvm::None; + + scalarNumBits.push_back( vectorType.getElementType().getIntOrFloatBitWidth()); - else - scalarTypes.push_back(type.getIntOrFloatBitWidth()); + totalNumBits.push_back(*numBytes * 8); + hasVector = true; + } else { + scalarNumBits.push_back(type.getIntOrFloatBitWidth()); + totalNumBits.push_back(scalarNumBits.back()); + } } - return llvm::is_splat(scalarTypes); + + if (hasVector) { + // If there are vector types, require all element types to be the same for + // now to simplify the transformation. + if (!llvm::is_splat(scalarNumBits)) + return llvm::None; + + // Choose the one with the largest bitwidth as the canonical resource, so + // that we can still keep vectorized load/store. + auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end()); + // Make sure that the canonical resource's bitwidth is divisible by others. + // Without this, we cannot properly adjust the index later. + if (llvm::any_of(totalNumBits, + [maxVal](int64_t bits) { return *maxVal % bits != 0; })) + return llvm::None; + + return std::distance(totalNumBits.begin(), maxVal); + } + + // All element types are scalars. Then choose the smallest bitwidth as the + // canonical resource to avoid subcomponent load/store. 
+ auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end()); + if (llvm::any_of(scalarNumBits, + [minVal](int64_t bit) { return bit % *minVal != 0; })) + return llvm::None; + return std::distance(scalarNumBits.begin(), minVal); +} + +static bool areSameBitwidthScalarType(Type a, Type b) { + return a.isIntOrFloat() && b.isIntOrFloat() && + a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth(); } //===----------------------------------------------------------------------===// @@ -203,11 +249,8 @@ void ResourceAliasAnalysis::recordIfUnifiable( const Descriptor &descriptor, ArrayRef resources) { - // Collect the element types and byte counts for all resources in the - // current set. + // Collect the element types for all resources in the current set. SmallVector elementTypes; - SmallVector numBytes; - for (spirv::GlobalVariableOp resource : resources) { Type elementType = getRuntimeArrayElementType(resource.type()); if (!elementType) @@ -217,37 +260,16 @@ if (!type.isScalarOrVector()) return; // Unexpected resource element type. - if (auto vectorType = type.dyn_cast()) - if (vectorType.getNumElements() % 2 != 0) - return; // Odd-sized vector has special layout requirements. - - Optional count = type.getSizeInBytes(); - if (!count) - return; - elementTypes.push_back(type); - numBytes.push_back(*count); } - // Make sure base scalar types have the same bitwdith, so that we don't need - // to handle extracting components for now. - if (!hasSameBitwidthScalarType(elementTypes)) - return; - - // Make sure that the canonical resource's bitwidth is divisible by others. - // With out this, we cannot properly adjust the index later. 
- auto *maxCount = std::max_element(numBytes.begin(), numBytes.end()); - if (llvm::any_of(numBytes, [maxCount](int64_t count) { - return *maxCount % count != 0; - })) + Optional index = deduceCanonicalResource(elementTypes); + if (!index) return; - spirv::GlobalVariableOp canonicalResource = - resources[std::distance(numBytes.begin(), maxCount)]; - // Update internal data structures for later use. resourceMap[descriptor].assign(resources.begin(), resources.end()); - canonicalResourceMap[descriptor] = canonicalResource; + canonicalResourceMap[descriptor] = resources[*index]; for (const auto &resource : llvm::enumerate(resources)) { descriptorMap[resource.value()] = descriptor; elementTypeMap[resource.value()] = elementTypes[resource.index()]; @@ -316,8 +338,8 @@ spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp); - if ((srcElemType == dstElemType) || - (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) { + if (srcElemType == dstElemType || + areSameBitwidthScalarType(srcElemType, dstElemType)) { // We have the same bitwidth for source and destination element types. // Thie indices keep the same. rewriter.replaceOpWithNewOp( @@ -333,7 +355,10 @@ // them into a buffer with vector element types. We need to scale the last // index for the vector as a whole, then add one level of index for inside // the vector. - int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes(); + int srcNumBits = *srcElemType.getSizeInBytes(); + int dstNumBits = *dstElemType.getSizeInBytes(); + assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0); + int ratio = dstNumBits / srcNumBits; auto ratioValue = rewriter.create( loc, i32Type, rewriter.getI32IntegerAttr(ratio)); @@ -349,6 +374,27 @@ return success(); } + if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) { + // The source indices are for a buffer with larger bitwidth scalar element + // types. 
Rewrite them into a buffer with smaller bitwidth element types. + // We only need to scale the last index. + int srcNumBits = *srcElemType.getSizeInBytes(); + int dstNumBits = *dstElemType.getSizeInBytes(); + assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); + int ratio = srcNumBits / dstNumBits; + auto ratioValue = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(ratio)); + + auto indices = llvm::to_vector<4>(acOp.indices()); + Value oldIndex = indices.back(); + indices.back() = + rewriter.create(loc, i32Type, oldIndex, ratioValue); + + rewriter.replaceOpWithNewOp( + acOp, adaptor.base_ptr(), indices); + return success(); + } + return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types"); } }; @@ -370,12 +416,56 @@ auto newLoadOp = rewriter.create(loc, adaptor.ptr()); if (srcElemType == dstElemType) { rewriter.replaceOp(loadOp, newLoadOp->getResults()); - } else { + return success(); + } + + if (areSameBitwidthScalarType(srcElemType, dstElemType)) { auto castOp = rewriter.create(loc, srcElemType, newLoadOp.value()); rewriter.replaceOp(loadOp, castOp->getResults()); + + return success(); } + // The source and destination have scalar types of different bitwidths. + // For such cases, we need to load multiple smaller bitwidth values and + // construct a larger bitwidth one. 
+ + int srcNumBits = srcElemType.getIntOrFloatBitWidth(); + int dstNumBits = dstElemType.getIntOrFloatBitWidth(); + assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); + int ratio = srcNumBits / dstNumBits; + if (ratio > 4) + return rewriter.notifyMatchFailure(loadOp, "more than 4 components"); + + SmallVector components; + components.reserve(ratio); + components.push_back(newLoadOp); + + auto acOp = adaptor.ptr().getDefiningOp(); + if (!acOp) + return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain"); + + auto i32Type = rewriter.getI32Type(); + Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); + auto indices = llvm::to_vector<4>(acOp.indices()); + for (int i = 1; i < ratio; ++i) { + // Load all subsequent components belonging to this element. + indices.back() = rewriter.create(loc, i32Type, + indices.back(), oneValue); + auto componentAcOp = + rewriter.create(loc, acOp.base_ptr(), indices); + components.push_back(rewriter.create(loc, componentAcOp)); + } + std::reverse(components.begin(), components.end()); // For little endian. + + // Create a vector of the components and then cast back to the larger + // bitwidth element type. 
+ auto vectorType = VectorType::get({ratio}, dstElemType); + Value vectorValue = rewriter.create( + loc, vectorType, components); + rewriter.replaceOpWithNewOp(loadOp, srcElemType, + vectorValue); return success(); } }; @@ -392,6 +482,8 @@ adaptor.ptr().getType().cast().getPointeeType(); if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) return rewriter.notifyMatchFailure(storeOp, "not scalar type"); + if (!areSameBitwidthScalarType(srcElemType, dstElemType)) + return rewriter.notifyMatchFailure(storeOp, "different bitwidth"); Location loc = storeOp.getLoc(); Value value = adaptor.value(); diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource -verify-diagnostics %s | FileCheck %s spv.module Logical GLSL450 { spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> @@ -213,3 +213,68 @@ // CHECK: %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32 // CHECK: spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32 // CHECK: spv.ReturnValue %[[CAST1]] : i32 + +// ----- + +spv.module Logical GLSL450 { + spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + + spv.func @load_different_scalar_bitwidth(%index: i32) -> i64 "None" { + %c0 = spv.Constant 0 : i32 + + %addr0 = spv.mlir.addressof @var01s_i64 : !spv.ptr [0])>, StorageBuffer> + %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + %val0 = spv.Load "StorageBuffer" %ac0 : i64 + + spv.ReturnValue %val0 : i64 + } +} + +// CHECK-LABEL: spv.module + 
+// CHECK-NOT: @var01s_i64 +// CHECK: spv.GlobalVariable @var01s_f32 bind(0, 1) : !spv.ptr [0])>, StorageBuffer> +// CHECK-NOT: @var01s_i64 + +// CHECK: spv.func @load_different_scalar_bitwidth(%[[INDEX:.+]]: i32) +// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 +// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01s_f32 + +// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32 +// CHECK: %[[BASE:.+]] = spv.IMul %[[INDEX]], %[[TWO]] : i32 +// CHECK: %[[AC0:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[BASE]]] +// CHECK: %[[LOAD0:.+]] = spv.Load "StorageBuffer" %[[AC0]] : f32 + +// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 +// CHECK: %[[ADD:.+]] = spv.IAdd %[[BASE]], %[[ONE]] : i32 +// CHECK: %[[AC1:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[ADD]]] +// CHECK: %[[LOAD1:.+]] = spv.Load "StorageBuffer" %[[AC1]] : f32 + +// CHECK: %[[CC:.+]] = spv.CompositeConstruct %[[LOAD1]], %[[LOAD0]] +// CHECK: %[[CAST:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64 +// CHECK: spv.ReturnValue %[[CAST]] + +// ----- + +spv.module Logical GLSL450 { + spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + + spv.func @store_different_scalar_bitwidth(%i0: i32, %i1: i32) "None" { + %c0 = spv.Constant 0 : i32 + + %addr0 = spv.mlir.addressof @var01s_f32 : !spv.ptr [0])>, StorageBuffer> + %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + %f32val = spv.Load "StorageBuffer" %ac0 : f32 + %f64val = spv.FConvert %f32val : f32 to f64 + %i64val = spv.Bitcast %f64val : f64 to i64 + + %addr1 = spv.mlir.addressof @var01s_i64 : !spv.ptr [0])>, StorageBuffer> + %ac1 = spv.AccessChain %addr1[%c0, %i1] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + // expected-error@+1 {{failed to legalize operation 'spv.Store'}} + spv.Store "StorageBuffer" %ac1, %i64val : i64 + + spv.Return + } +}