diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -20,7 +20,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/AnalysisManager.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -77,12 +76,15 @@ /// resource that all resources should be unified into. Returns llvm::None if /// unable to unify. static Optional deduceCanonicalResource(ArrayRef types) { - SmallVector scalarNumBits, totalNumBits; + // scalarNumBits: contains all resources' scalar types' bit counts. + // vectorNumBits: only contains resources whose element types are vectors. + SmallVector scalarNumBits, vectorNumBits, vectorIndices; scalarNumBits.reserve(types.size()); - totalNumBits.reserve(types.size()); - bool hasVector = false; + vectorNumBits.reserve(types.size()); + vectorIndices.reserve(types.size()); - for (spirv::SPIRVType type : types) { + for (const auto &indexedTypes : llvm::enumerate(types)) { + spirv::SPIRVType type = indexedTypes.value(); assert(type.isScalarOrVector()); if (auto vectorType = type.dyn_cast()) { if (vectorType.getNumElements() % 2 != 0) @@ -94,30 +96,40 @@ scalarNumBits.push_back( vectorType.getElementType().getIntOrFloatBitWidth()); - totalNumBits.push_back(*numBytes * 8); - hasVector = true; + vectorNumBits.push_back(*numBytes * 8); + vectorIndices.push_back(indexedTypes.index()); } else { scalarNumBits.push_back(type.getIntOrFloatBitWidth()); - totalNumBits.push_back(scalarNumBits.back()); } } - if (hasVector) { + if (!vectorNumBits.empty()) { // If there are vector types, require all element types to be the same for // now to simplify the transformation.
if (!llvm::is_splat(scalarNumBits)) return llvm::None; + // is_splat() above only compares bitwidths. Also require all vector + // resources to share one element type: the spv.AccessChain/spv.Load + // conversions assume differently-sized vectors of the same element type. + Type vectorElemType = + types[vectorIndices.front()].cast<VectorType>().getElementType(); + if (llvm::any_of(vectorIndices, [&](int64_t index) { + return types[index].cast<VectorType>().getElementType() != + vectorElemType; + })) + return llvm::None; - // Choose the one with the largest bitwidth as the canonical resource, so - // that we can still keep vectorized load/store. - auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end()); + // Choose the *vector* with the smallest bitwidth as the canonical resource, + // so that we can still keep vectorized load/store and avoid partial updates + // to large vectors. + auto *minVal = std::min_element(vectorNumBits.begin(), vectorNumBits.end()); // Make sure that the canonical resource's bitwidth is divisible by others. // With out this, we cannot properly adjust the index later. - if (llvm::any_of(totalNumBits, - [maxVal](int64_t bits) { return *maxVal % bits != 0; })) + if (llvm::any_of(vectorNumBits, + [minVal](int64_t bits) { return bits % *minVal != 0; })) return llvm::None; - return std::distance(totalNumBits.begin(), maxVal); + return vectorIndices[std::distance(vectorNumBits.begin(), minVal)]; } // All element types are scalars. Then choose the smallest bitwidth as the @@ -374,10 +386,11 @@ return success(); } - if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) { - // The source indices are for a buffer with larger bitwidth scalar element - // types. Rewrite them into a buffer with smaller bitwidth element types. - // We only need to scale the last index. + if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || + (srcElemType.isa() && dstElemType.isa())) { + // The source indices are for a buffer with larger bitwidth scalar/vector + // element types. Rewrite them into a buffer with smaller bitwidth element + // types. We only need to scale the last index.
int srcNumBits = *srcElemType.getSizeInBytes(); int dstNumBits = *dstElemType.getSizeInBytes(); assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); @@ -395,7 +408,8 @@ return success(); } - return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types"); + return rewriter.notifyMatchFailure( + acOp, "unsupported src/dst types for spv.AccessChain"); } }; @@ -405,12 +419,10 @@ LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcElemType = - loadOp.ptr().getType().cast().getPointeeType(); - auto dstElemType = - adaptor.ptr().getType().cast().getPointeeType(); - if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) - return rewriter.notifyMatchFailure(loadOp, "not scalar type"); + auto srcPtrType = loadOp.ptr().getType().cast(); + auto srcElemType = srcPtrType.getPointeeType().cast(); + auto dstPtrType = adaptor.ptr().getType().cast(); + auto dstElemType = dstPtrType.getPointeeType().cast(); Location loc = loadOp.getLoc(); auto newLoadOp = rewriter.create(loc, adaptor.ptr()); @@ -427,48 +439,60 @@ return success(); } - // The source and destination have scalar types of different bitwidths. - // For such cases, we need to load multiple smaller bitwidth values and - // construct a larger bitwidth one. + if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || + (srcElemType.isa() && dstElemType.isa())) { + // The source and destination have scalar types of different bitwidths, or + // vector types of different component counts. For such cases, we load + // multiple smaller bitwidth values and construct a larger bitwidth one.
- int srcNumBits = srcElemType.getIntOrFloatBitWidth(); - int dstNumBits = dstElemType.getIntOrFloatBitWidth(); - assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); - int ratio = srcNumBits / dstNumBits; - if (ratio > 4) - return rewriter.notifyMatchFailure(loadOp, "more than 4 components"); + int srcNumBits = *srcElemType.getSizeInBytes() * 8; + int dstNumBits = *dstElemType.getSizeInBytes() * 8; + assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); + int ratio = srcNumBits / dstNumBits; + if (ratio > 4) + return rewriter.notifyMatchFailure(loadOp, "more than 4 components"); - SmallVector components; - components.reserve(ratio); - components.push_back(newLoadOp); + SmallVector components; + components.reserve(ratio); + components.push_back(newLoadOp); - auto acOp = adaptor.ptr().getDefiningOp(); - if (!acOp) - return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain"); + auto acOp = adaptor.ptr().getDefiningOp(); + if (!acOp) + return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain"); - auto i32Type = rewriter.getI32Type(); - Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); - auto indices = llvm::to_vector<4>(acOp.indices()); - for (int i = 1; i < ratio; ++i) { - // Load all subsequent components belonging to this element. - indices.back() = rewriter.create(loc, i32Type, - indices.back(), oneValue); - auto componentAcOp = - rewriter.create(loc, acOp.base_ptr(), indices); - // Assuming little endian, this reads lower-ordered bits of the number to - // lower-numbered components of the vector. - components.push_back(rewriter.create(loc, componentAcOp)); + auto i32Type = rewriter.getI32Type(); + Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); + auto indices = llvm::to_vector<4>(acOp.indices()); + for (int i = 1; i < ratio; ++i) { + // Load all subsequent components belonging to this element.
+ indices.back() = rewriter.create( + loc, i32Type, indices.back(), oneValue); + auto componentAcOp = rewriter.create( + loc, acOp.base_ptr(), indices); + // Assuming little endian, this reads lower-ordered bits of the number + // to lower-numbered components of the vector. + components.push_back( + rewriter.create(loc, componentAcOp)); + } + + // Create a vector of the components and then cast back to the larger + // bitwidth element type. For spv.bitcast, the lower-numbered components + // of the vector map to lower-ordered bits of the larger bitwidth element + // type. + Type vectorType = srcElemType; + if (!srcElemType.isa()) + vectorType = VectorType::get({ratio}, dstElemType); + Value vectorValue = rewriter.create( + loc, vectorType, components); + if (!srcElemType.isa()) + vectorValue = + rewriter.create(loc, srcElemType, vectorValue); + rewriter.replaceOp(loadOp, vectorValue); + return success(); } - // Create a vector of the components and then cast back to the larger - // bitwidth element type. For spv.bitcast, the lower-numbered components of - // the vector map to lower-ordered bits of the larger bitwidth element type.
- auto vectorType = VectorType::get({ratio}, dstElemType); - Value vectorValue = rewriter.create( - loc, vectorType, components); - rewriter.replaceOpWithNewOp(loadOp, srcElemType, - vectorValue); - return success(); + return rewriter.notifyMatchFailure( + loadOp, "unsupported src/dst types for spv.Load"); } }; diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir @@ -278,3 +278,54 @@ spv.Return } } + +// ----- + +spv.module Logical GLSL450 { + spv.GlobalVariable @var01_scalar bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + spv.GlobalVariable @var01_vec2 bind(0, 1) {aliased} : !spv.ptr, stride=8> [0])>, StorageBuffer> + spv.GlobalVariable @var01_vec4 bind(0, 1) {aliased} : !spv.ptr, stride=16> [0])>, StorageBuffer> + + spv.func @load_different_vector_sizes(%i0: i32) -> vector<4xf32> "None" { + %c0 = spv.Constant 0 : i32 + + %addr0 = spv.mlir.addressof @var01_vec4 : !spv.ptr, stride=16> [0])>, StorageBuffer> + %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr, stride=16> [0])>, StorageBuffer>, i32, i32 + %vec4val = spv.Load "StorageBuffer" %ac0 : vector<4xf32> + + %addr1 = spv.mlir.addressof @var01_scalar : !spv.ptr [0])>, StorageBuffer> + %ac1 = spv.AccessChain %addr1[%c0, %i0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + %scalarval = spv.Load "StorageBuffer" %ac1 : f32 + + %val = spv.CompositeInsert %scalarval, %vec4val[0 : i32] : f32 into vector<4xf32> + spv.ReturnValue %val : vector<4xf32> + } +} + +// CHECK-LABEL: spv.module + +// CHECK-NOT: @var01_scalar +// CHECK-NOT: @var01_vec4 +// CHECK: spv.GlobalVariable @var01_vec2 bind(0, 1) : !spv.ptr<{{.+}}> +// CHECK-NOT: @var01_scalar +// CHECK-NOT: @var01_vec4 + +// CHECK: spv.func @load_different_vector_sizes(%[[IDX:.+]]: i32) +// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 +// CHECK:
%[[ADDR:.+]] = spv.mlir.addressof @var01_vec2 +// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32 +// CHECK: %[[IDX0:.+]] = spv.IMul %[[IDX]], %[[TWO]] : i32 +// CHECK: %[[AC0:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[IDX0]]] +// CHECK: %[[LD0:.+]] = spv.Load "StorageBuffer" %[[AC0]] : vector<2xf32> +// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 +// CHECK: %[[IDX1:.+]] = spv.IAdd %[[IDX0]], %[[ONE]] : i32 +// CHECK: %[[AC1:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[IDX1]]] +// CHECK: %[[LD1:.+]] = spv.Load "StorageBuffer" %[[AC1]] : vector<2xf32> +// CHECK: spv.CompositeConstruct %[[LD0]], %[[LD1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32> + +// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01_vec2 +// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32 +// CHECK: %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[TWO]] : i32 +// CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32 +// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[DIV]], %[[MOD]]] +// CHECK: %[[LD:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32