diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -287,13 +287,15 @@ | vector-type | spirv-type -array-type ::= `!spv.array<` integer-literal `x` element-type `>` +array-type ::= `!spv.array` `<` integer-literal `x` element-type + (`,` `stride` `=` integer-literal)? `>` ``` For example, ```mlir !spv.array<4 x i32> +!spv.array<4 x i32, stride = 4> !spv.array<16 x vector<4 x f32>> ``` @@ -351,13 +353,14 @@ This corresponds to SPIR-V [runtime array type][RuntimeArrayType]. Its syntax is ``` -runtime-array-type ::= `!spv.rtarray<` element-type `>` +runtime-array-type ::= `!spv.rtarray` `<` element-type (`,` `stride` `=` integer-literal)? `>` ``` For example, ```mlir !spv.rtarray<i32> +!spv.rtarray<i32, stride = 4> !spv.rtarray<vector<4 x f32>> ``` diff --git a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h --- a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h +++ b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h @@ -18,9 +18,11 @@ namespace mlir { class Type; class VectorType; + namespace spirv { -class StructType; class ArrayType; +class RuntimeArrayType; +class StructType; } // namespace spirv /// According to the Vulkan spec "14.5.4. Offset and Stride Assignment": @@ -47,10 +49,9 @@ public: using Size = uint64_t; - /// Returns a new StructType with layout info. Assigns the type size in bytes - /// to the `size`. Assigns the type alignment in bytes to the `alignment`. - static spirv::StructType decorateType(spirv::StructType structType, - Size &size, Size &alignment); + /// Returns a new StructType with layout decoration. + static spirv::StructType decorateType(spirv::StructType structType); + /// Checks whether a type is legal in terms of Vulkan layout info /// decoration.
A type is dynamically illegal if it's a composite type in the /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage @@ -58,10 +59,17 @@ static bool isLegalType(Type type); private: + /// Returns a new type with layout decoration. Assigns the type size in bytes + /// to the `size`. Assigns the type alignment in bytes to the `alignment`. static Type decorateType(Type type, Size &size, Size &alignment); + static Type decorateType(VectorType vectorType, Size &size, Size &alignment); static Type decorateType(spirv::ArrayType arrayType, Size &size, Size &alignment); + static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment); + static spirv::StructType decorateType(spirv::StructType structType, + Size &size, Size &alignment); + /// Calculates the alignment for the given scalar type. static Size getScalarTypeAlignment(Type scalarType); }; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -147,23 +147,22 @@ detail::ArrayTypeStorage> { public: using Base::Base; - // Zero layout specifies that is no layout - using LayoutInfo = uint64_t; static bool kindof(unsigned kind) { return kind == TypeKind::Array; } static ArrayType get(Type elementType, unsigned elementCount); + /// Returns an array type with the given stride in bytes. static ArrayType get(Type elementType, unsigned elementCount, - LayoutInfo layoutInfo); + unsigned stride); unsigned getNumElements() const; Type getElementType() const; - bool hasLayout() const; - - uint64_t getArrayStride() const; + /// Returns the array stride in bytes. 0 means no stride decorated on this + /// type. 
+ unsigned getArrayStride() const; void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); @@ -243,8 +242,15 @@ static RuntimeArrayType get(Type elementType); + /// Returns a runtime array type with the given stride in bytes. + static RuntimeArrayType get(Type elementType, unsigned stride); + Type getElementType() const; + /// Returns the array stride in bytes. 0 means no stride decorated on this + /// type. + unsigned getArrayStride() const; + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp --- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -17,6 +17,13 @@ using namespace mlir; spirv::StructType +VulkanLayoutUtils::decorateType(spirv::StructType structType) { + Size size = 0; + Size alignment = 1; + return decorateType(structType, size, alignment); +} + +spirv::StructType VulkanLayoutUtils::decorateType(spirv::StructType structType, VulkanLayoutUtils::Size &size, VulkanLayoutUtils::Size &alignment) { @@ -25,21 +32,26 @@ } SmallVector memberTypes; - SmallVector layoutInfo; + SmallVector layoutInfo; SmallVector memberDecorations; - VulkanLayoutUtils::Size structMemberOffset = 0; - VulkanLayoutUtils::Size maxMemberAlignment = 1; + Size structMemberOffset = 0; + Size maxMemberAlignment = 1; for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { - VulkanLayoutUtils::Size memberSize = 0; - VulkanLayoutUtils::Size memberAlignment = 1; + Size memberSize = 0; + Size memberAlignment = 1; - auto memberType = VulkanLayoutUtils::decorateType( - structType.getElementType(i), memberSize, memberAlignment); + auto memberType = + decorateType(structType.getElementType(i), memberSize, memberAlignment); structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment); 
memberTypes.push_back(memberType); layoutInfo.push_back(structMemberOffset); + // If the member's size is the max value, it must be the last member and it + // must be a runtime array. + assert(memberSize != std::numeric_limits().max() || + (i + 1 == e && + structType.getElementType(i).isa())); // According to the Vulkan spec: // "A structure has a base alignment equal to the largest base alignment of // any of its members." @@ -60,7 +72,7 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, VulkanLayoutUtils::Size &alignment) { if (type.isa()) { - alignment = VulkanLayoutUtils::getScalarTypeAlignment(type); + alignment = getScalarTypeAlignment(type); // Vulkan spec does not specify any padding for a scalar type. size = alignment; return type; @@ -68,14 +80,14 @@ switch (type.getKind()) { case spirv::TypeKind::Struct: - return VulkanLayoutUtils::decorateType(type.cast(), size, - alignment); + return decorateType(type.cast(), size, alignment); case spirv::TypeKind::Array: - return VulkanLayoutUtils::decorateType(type.cast(), size, - alignment); + return decorateType(type.cast(), size, alignment); case StandardTypes::Vector: - return VulkanLayoutUtils::decorateType(type.cast(), size, - alignment); + return decorateType(type.cast(), size, alignment); + case spirv::TypeKind::RuntimeArray: + size = std::numeric_limits().max(); + return decorateType(type.cast(), alignment); default: llvm_unreachable("unhandled SPIR-V type"); } @@ -86,11 +98,10 @@ VulkanLayoutUtils::Size &alignment) { const auto numElements = vectorType.getNumElements(); auto elementType = vectorType.getElementType(); - VulkanLayoutUtils::Size elementSize = 0; - VulkanLayoutUtils::Size elementAlignment = 1; + Size elementSize = 0; + Size elementAlignment = 1; - auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, - elementAlignment); + auto memberType = decorateType(elementType, elementSize, elementAlignment); // According to the Vulkan spec: // 1. 
"A two-component vector has a base alignment equal to twice its scalar // alignment." @@ -106,11 +117,10 @@ VulkanLayoutUtils::Size &alignment) { const auto numElements = arrayType.getNumElements(); auto elementType = arrayType.getElementType(); - spirv::ArrayType::LayoutInfo elementSize = 0; - VulkanLayoutUtils::Size elementAlignment = 1; + Size elementSize = 0; + Size elementAlignment = 1; - auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize, - elementAlignment); + auto memberType = decorateType(elementType, elementSize, elementAlignment); // According to the Vulkan spec: // "An array has a base alignment equal to the base alignment of its element // type." @@ -119,6 +129,15 @@ return spirv::ArrayType::get(memberType, numElements, elementSize); } +Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType, + VulkanLayoutUtils::Size &alignment) { + auto elementType = arrayType.getElementType(); + Size elementSize = 0; + + auto memberType = decorateType(elementType, elementSize, alignment); + return spirv::RuntimeArrayType::get(memberType, elementSize); +} + VulkanLayoutUtils::Size VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { // According to the Vulkan spec: diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -149,7 +149,7 @@ DialectAsmParser &parser); template <> -Optional parseAndVerify(SPIRVDialect const &dialect, +Optional parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser); static Type parseAndVerifyType(SPIRVDialect const &dialect, @@ -196,13 +196,39 @@ return type; } +/// Parses an optional `, stride = N` assembly segment. If no parsing failure +/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if +/// missing. 
+static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, + DialectAsmParser &parser, + unsigned &stride) { + if (failed(parser.parseOptionalComma())) { + stride = 0; + return success(); + } + + if (parser.parseKeyword("stride") || parser.parseEqual()) + return failure(); + + llvm::SMLoc strideLoc = parser.getCurrentLocation(); + Optional optStride = parseAndVerify(dialect, parser); + if (!optStride) + return failure(); + + if (!(stride = optStride.getValue())) { + parser.emitError(strideLoc, "ArrayStride must be greater than zero"); + return failure(); + } + return success(); +} + // element-type ::= integer-type // | floating-point-type // | vector-type // | spirv-type // -// array-type ::= `!spv.array<` integer-literal `x` element-type -// (`[` integer-literal `]`)? `>` +// array-type ::= `!spv.array` `<` integer-literal `x` element-type +// (`,` `stride` `=` integer-literal)? `>` static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser) { if (parser.parseLess()) @@ -230,25 +256,13 @@ if (!elementType) return Type(); - ArrayType::LayoutInfo layoutInfo = 0; - if (succeeded(parser.parseOptionalLSquare())) { - llvm::SMLoc layoutLoc = parser.getCurrentLocation(); - auto layout = parseAndVerify(dialect, parser); - if (!layout) - return Type(); - - if (!(layoutInfo = layout.getValue())) { - parser.emitError(layoutLoc, "ArrayStride must be greater than zero"); - return Type(); - } - - if (parser.parseRSquare()) - return Type(); - } + unsigned stride = 0; + if (failed(parseOptionalArrayStride(dialect, parser, stride))) + return Type(); if (parser.parseGreater()) return Type(); - return ArrayType::get(elementType, count, layoutInfo); + return ArrayType::get(elementType, count, stride); } // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type @@ -285,7 +299,8 @@ return PointerType::get(pointeeType, *storageClass); } -// runtime-array-type ::= `!spv.rtarray<` element-type `>` +// runtime-array-type ::= 
`!spv.rtarray` `<` element-type +// (`,` `stride` `=` integer-literal)? `>` static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser) { if (parser.parseLess()) @@ -295,9 +310,13 @@ if (!elementType) return Type(); + unsigned stride = 0; + if (failed(parseOptionalArrayStride(dialect, parser, stride))) + return Type(); + if (parser.parseGreater()) return Type(); - return RuntimeArrayType::get(elementType); + return RuntimeArrayType::get(elementType, stride); } // Specialize this function to parse each of the parameters that define an @@ -337,9 +356,9 @@ } template <> -Optional parseAndVerify(SPIRVDialect const &dialect, +Optional parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser) { - return parseAndVerifyInteger(dialect, parser); + return parseAndVerifyInteger(dialect, parser); } namespace { @@ -526,14 +545,16 @@ static void print(ArrayType type, DialectAsmPrinter &os) { os << "array<" << type.getNumElements() << " x " << type.getElementType(); - if (type.hasLayout()) { - os << " [" << type.getArrayStride() << "]"; - } + if (unsigned stride = type.getArrayStride()) + os << ", stride=" << stride; os << ">"; } static void print(RuntimeArrayType type, DialectAsmPrinter &os) { - os << "rtarray<" << type.getElementType() << ">"; + os << "rtarray<" << type.getElementType(); + if (unsigned stride = type.getArrayStride()) + os << ", stride=" << stride; + os << ">"; } static void print(PointerType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -100,7 +100,7 @@ //===----------------------------------------------------------------------===// struct spirv::detail::ArrayTypeStorage : public TypeStorage { - using KeyTy = std::tuple; + using KeyTy = std::tuple; static ArrayTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { @@ -108,28 +108,28 @@ 
} bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, getSubclassData(), layoutInfo); + return key == KeyTy(elementType, getSubclassData(), stride); } ArrayTypeStorage(const KeyTy &key) : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)), - layoutInfo(std::get<2>(key)) {} + stride(std::get<2>(key)) {} Type elementType; - ArrayType::LayoutInfo layoutInfo; + unsigned stride; }; ArrayType ArrayType::get(Type elementType, unsigned elementCount) { assert(elementCount && "ArrayType needs at least one element"); return Base::get(elementType.getContext(), TypeKind::Array, elementType, - elementCount, 0); + elementCount, /*stride=*/0); } ArrayType ArrayType::get(Type elementType, unsigned elementCount, - ArrayType::LayoutInfo layoutInfo) { + unsigned stride) { assert(elementCount && "ArrayType needs at least one element"); return Base::get(elementType.getContext(), TypeKind::Array, elementType, - elementCount, layoutInfo); + elementCount, stride); } unsigned ArrayType::getNumElements() const { @@ -138,10 +138,7 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; } -// ArrayStride must be greater than zero -bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; } - -uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; } +unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { @@ -215,7 +212,7 @@ cast().getExtensions(extensions, storage); break; case spirv::TypeKind::RuntimeArray: - cast().getExtensions(extensions, storage); + cast().getExtensions(extensions, storage); break; case spirv::TypeKind::Struct: cast().getExtensions(extensions, storage); @@ -237,7 +234,7 @@ cast().getCapabilities(capabilities, storage); break; case spirv::TypeKind::RuntimeArray: - cast().getCapabilities(capabilities, storage); + cast().getCapabilities(capabilities, storage); break; case 
spirv::TypeKind::Struct: cast().getCapabilities(capabilities, storage); @@ -523,7 +520,7 @@ //===----------------------------------------------------------------------===// struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage { - using KeyTy = Type; + using KeyTy = std::pair; static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { @@ -531,20 +528,32 @@ RuntimeArrayTypeStorage(key); } - bool operator==(const KeyTy &key) const { return elementType == key; } + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, getSubclassData()); + } - RuntimeArrayTypeStorage(const KeyTy &key) : elementType(key) {} + RuntimeArrayTypeStorage(const KeyTy &key) + : TypeStorage(key.second), elementType(key.first) {} Type elementType; }; RuntimeArrayType RuntimeArrayType::get(Type elementType) { return Base::get(elementType.getContext(), TypeKind::RuntimeArray, - elementType); + elementType, /*stride=*/0); +} + +RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) { + return Base::get(elementType.getContext(), TypeKind::RuntimeArray, + elementType, stride); } Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } +unsigned RuntimeArrayType::getArrayStride() const { + return getImpl()->getSubclassData(); +} + void RuntimeArrayType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1203,7 +1203,8 @@ "OpTypeRuntimeArray references undefined ") << operands[1]; } - typeMap[operands[0]] = spirv::RuntimeArrayType::get(memberType); + typeMap[operands[0]] = spirv::RuntimeArrayType::get( + memberType, typeDecorations.lookup(operands[0])); return success(); } diff --git 
a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -676,10 +676,19 @@ template <> LogicalResult Serializer::processTypeDecoration( Location loc, spirv::ArrayType type, uint32_t resultID) { - if (type.hasLayout()) { + if (unsigned stride = type.getArrayStride()) { // OpDecorate %arrayTypeSSA ArrayStride strideLiteral - return emitDecoration(resultID, spirv::Decoration::ArrayStride, - {static_cast(type.getArrayStride())}); + return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); + } + return success(); +} + +template <> +LogicalResult Serializer::processTypeDecoration( + Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) { + if (unsigned stride = type.getArrayStride()) { + // OpDecorate %arrayTypeSSA ArrayStride strideLiteral + return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); } return success(); } @@ -1011,9 +1020,9 @@ elementTypeID))) { return failure(); } - operands.push_back(elementTypeID); typeEnum = spirv::Opcode::OpTypeRuntimeArray; - return success(); + operands.push_back(elementTypeID); + return processTypeDecoration(loc, runtimeArrayType, resultID); } if (auto structType = type.dyn_cast()) { diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -30,14 +30,11 @@ LogicalResult matchAndRewrite(spirv::GlobalVariableOp op, PatternRewriter &rewriter) const override { - spirv::StructType::LayoutInfo structSize = 0; - VulkanLayoutUtils::Size structAlignment = 1; SmallVector globalVarAttrs; auto ptrType = op.type().cast(); auto structType = 
VulkanLayoutUtils::decorateType( - ptrType.getPointeeType().cast(), structSize, - structAlignment); + ptrType.getPointeeType().cast()); auto decoratedType = spirv::PointerType::get(structType, ptrType.getStorageClass()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -50,10 +50,8 @@ auto varPointeeType = varPtrType.getPointeeType().cast(); // Set the offset information. - VulkanLayoutUtils::Size size = 0, alignment = 0; varPointeeType = - VulkanLayoutUtils::decorateType(varPointeeType, size, alignment) - .cast(); + VulkanLayoutUtils::decorateType(varPointeeType).cast(); varType = spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -27,9 +27,9 @@ // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr, Input> // CHECK-LABEL: spv.func @load_store_kernel - // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}} - // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} - // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, 
StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}} // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -5,7 +5,7 @@ // CHECK: spv.module Logical GLSL450 { // CHECK-LABEL: spv.func @basic_module_structure // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} { diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir --- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir +++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir @@ -6,12 +6,12 @@ module attributes {gpu.container_module} { spv.module 
Logical GLSL450 requires #spv.vce { - spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} { - %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> %2 = spv.constant 0 : i32 - %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> - %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer> + %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer> %5 = spv.Load "StorageBuffer" %4 : f32 spv.Return } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -312,17 +312,17 @@ %3 = constant dense<[2, 3]> : vector<2xi32> // CHECK: spv.constant 1 : i32 %4 = constant 1 : index - // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]> + // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32, stride=4> %5 = constant dense<1> : tensor<2x3xi32> - // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]> + // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32, stride=4> %6 = constant dense<1.0> : tensor<2x3xf32> - // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]> + // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32, stride=4> %7 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 
5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]> + // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4> %8 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]> + // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4> %9 = constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> - // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]> + // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4> %10 = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> return } @@ -335,7 +335,7 @@ %1 = constant 5.0 : f16 // CHECK: spv.constant dense<[2, 3]> : vector<2xi16> %2 = constant dense<[2, 3]> : vector<2xi16> - // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16 [2]> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16, stride=2> %3 = constant dense<4.0> : tensor<5xf16> return } @@ -348,7 +348,7 @@ %1 = constant 5.0 : f64 // CHECK: spv.constant dense<[2, 3]> : vector<2xi64> %2 = constant dense<[2, 3]> : vector<2xi64> - // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64 [8]> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64, stride=8> %3 = constant dense<4.0> : tensor<5xf64> return } @@ -373,9 +373,9 @@ %1 = constant 5.0 : f16 // CHECK: spv.constant dense<[2, 3]> : vector<2xi32> %2 = constant dense<[2, 3]> : vector<2xi16> - // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4> %3 = constant dense<4.0> : tensor<5xf16> - // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]> + // CHECK: 
spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4> %4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16> return } @@ -388,9 +388,9 @@ %1 = constant 5.0 : f64 // CHECK: spv.constant dense<[2, 3]> : vector<2xi32> %2 = constant dense<[2, 3]> : vector<2xi64> - // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4> %3 = constant dense<4.0> : tensor<5xf64> - // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]> + // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4> %4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16> return } @@ -572,8 +572,8 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @load_store_zero_rank_float -// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ @@ -591,8 +591,8 @@ } // CHECK-LABEL: @load_store_zero_rank_int -// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- 
a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -311,35 +311,35 @@ } { // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0]>, Uniform> func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return } // CHECK-LABEL: spv.func @memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0]>, PushConstant> func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0]>, Uniform> func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0]>, PushConstant> func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0]>, Input> +// CHECK-SAME: !spv.ptr [0]>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0]>, Output> +// CHECK-SAME: !spv.ptr [0]>, Output> func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } } // end module @@ -358,12 +358,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0]>, PushConstant> func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { 
return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant -// CHECK-SAME: !spv.ptr [0]>, PushConstant> -// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0]>, PushConstant> +// CHECK-SAME: !spv.ptr [0]>, PushConstant> func @memref_16bit_PushConstant( %arg0: memref<16xi16, 7>, %arg1: memref<16xf16, 7> @@ -385,12 +385,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> -// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> func @memref_16bit_StorageBuffer( %arg0: memref<16xi16, 0>, %arg1: memref<16xf16, 0> @@ -412,12 +412,12 @@ } { // CHECK-LABEL: spv.func @memref_8bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0]>, Uniform> func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform -// CHECK-SAME: !spv.ptr [0]>, Uniform> -// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0]>, Uniform> +// CHECK-SAME: !spv.ptr [0]>, Uniform> func @memref_16bit_Uniform( %arg0: memref<16xi16, 4>, %arg1: memref<16xf16, 4> @@ -438,11 +438,11 @@ } { // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0]>, Input> +// CHECK-SAME: !spv.ptr [0]>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0]>, Output> +// CHECK-SAME: !spv.ptr [0]>, Output> func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } } // end module @@ -459,22 +459,22 @@ // CHECK-LABEL: spv.func @memref_offset_strides func @memref_offset_strides( -// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<72 x f32 [4]> 
[0]>, StorageBuffer> -// CHECK-SAME: !spv.array<256 x f32 [4]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<88 x f32 [4]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f32, stride=4> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f32, stride=4> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f32, stride=4> [0]>, StorageBuffer> %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>, // tightly packed; row major %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>, // offset 8 %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col -// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<72 x f16 [2]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<256 x f16 [2]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer> -// CHECK-SAME: !spv.array<88 x f16 [2]> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<72 x f16, stride=2> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<256 x f16, stride=2> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer> +// CHECK-SAME: !spv.array<88 x f16, stride=2> [0]>, StorageBuffer> %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>, %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>, %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>, @@ -519,10 +519,10 @@ } { // CHECK-LABEL: spv.func @int_tensor_types -// CHECK-SAME: !spv.array<32 x i64 [8]> -// CHECK-SAME: !spv.array<32 x i32 [4]> -// CHECK-SAME: !spv.array<32 x i16 [2]> -// CHECK-SAME: !spv.array<32 x i8 [1]> +// CHECK-SAME: 
!spv.array<32 x i64, stride=8> +// CHECK-SAME: !spv.array<32 x i32, stride=4> +// CHECK-SAME: !spv.array<32 x i16, stride=2> +// CHECK-SAME: !spv.array<32 x i8, stride=1> func @int_tensor_types( %arg0: tensor<8x4xi64>, %arg1: tensor<8x4xi32>, @@ -531,9 +531,9 @@ ) { return } // CHECK-LABEL: spv.func @float_tensor_types -// CHECK-SAME: !spv.array<32 x f64 [8]> -// CHECK-SAME: !spv.array<32 x f32 [4]> -// CHECK-SAME: !spv.array<32 x f16 [2]> +// CHECK-SAME: !spv.array<32 x f64, stride=8> +// CHECK-SAME: !spv.array<32 x f32, stride=4> +// CHECK-SAME: !spv.array<32 x f16, stride=2> func @float_tensor_types( %arg0: tensor<8x4xf64>, %arg1: tensor<8x4xf32>, @@ -553,10 +553,10 @@ } { // CHECK-LABEL: spv.func @int_tensor_types -// CHECK-SAME: !spv.array<32 x i32 [4]> -// CHECK-SAME: !spv.array<32 x i32 [4]> -// CHECK-SAME: !spv.array<32 x i32 [4]> -// CHECK-SAME: !spv.array<32 x i32 [4]> +// CHECK-SAME: !spv.array<32 x i32, stride=4> +// CHECK-SAME: !spv.array<32 x i32, stride=4> +// CHECK-SAME: !spv.array<32 x i32, stride=4> +// CHECK-SAME: !spv.array<32 x i32, stride=4> func @int_tensor_types( %arg0: tensor<8x4xi64>, %arg1: tensor<8x4xi32>, @@ -565,9 +565,9 @@ ) { return } // CHECK-LABEL: spv.func @float_tensor_types -// CHECK-SAME: !spv.array<32 x f32 [4]> -// CHECK-SAME: !spv.array<32 x f32 [4]> -// CHECK-SAME: !spv.array<32 x f32 [4]> +// CHECK-SAME: !spv.array<32 x f32, stride=4> +// CHECK-SAME: !spv.array<32 x f32, stride=4> +// CHECK-SAME: !spv.array<32 x f32, stride=4> func @float_tensor_types( %arg0: tensor<8x4xf64>, %arg1: tensor<8x4xf32>, diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -16,7 +16,7 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @fold_static_stride_subview_with_load -// CHECK-SAME: 
[[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32 +// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32 func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) { // CHECK: [[C2:%.*]] = spv.constant 2 // CHECK: [[C3:%.*]] = spv.constant 3 @@ -38,7 +38,7 @@ } // CHECK-LABEL: @fold_static_stride_subview_with_store -// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: f32 +// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: f32 func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) { // CHECK: [[C2:%.*]] = spv.constant 2 // CHECK: [[C3:%.*]] = spv.constant 3 diff --git a/mlir/test/Dialect/SPIRV/Serialization/array.mlir b/mlir/test/Dialect/SPIRV/Serialization/array.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/array.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/array.mlir @@ -1,9 +1,9 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - spv.func @array_stride(%arg0 : !spv.ptr [128]>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [128]>, StorageBuffer> - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr [128]>, StorageBuffer> + spv.func @array_stride(%arg0 : !spv.ptr, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr, stride=128>, StorageBuffer> + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr, stride=128>, 
StorageBuffer> spv.Return } } @@ -11,8 +11,8 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - // CHECK: spv.globalVariable {{@.*}} : !spv.ptr, StorageBuffer> - spv.globalVariable @var0 : !spv.ptr, StorageBuffer> + // CHECK: spv.globalVariable {{@.*}} : !spv.ptr, StorageBuffer> + spv.globalVariable @var0 : !spv.ptr, StorageBuffer> // CHECK: spv.globalVariable {{@.*}} : !spv.ptr>, Input> spv.globalVariable @var1 : !spv.ptr>, Input> } diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir @@ -226,16 +226,16 @@ } // CHECK-LABEL: @multi_dimensions_const - spv.func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) "None" { - // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> - %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> - spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + spv.func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>) "None" { + // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24> + %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24> + spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, 
stride=4>, stride=12>, stride=24> } // CHECK-LABEL: @multi_dimensions_splat_const - spv.func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) "None" { - // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> - %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> - spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + spv.func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>) "None" { + // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24> + %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24> + spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24> } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -60,14 +60,14 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> - spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> + spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0]>, StorageBuffer> spv.func @loop_kernel() "None" { - %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> + %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> 
%1 = spv.constant 0 : i32 - %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer> - %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> - %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer> + %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer> %6 = spv.constant 4 : i32 %7 = spv.constant 42 : i32 %8 = spv.constant 2 : i32 @@ -84,9 +84,9 @@ spv.BranchConditional %10, ^body, ^merge // CHECK-NEXT: ^bb2: // pred: ^bb1 ^body: - %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer> + %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer> %12 = spv.Load "StorageBuffer" %11 : f32 - %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer> + %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer> spv.Store "StorageBuffer" %13, %12 : f32 // CHECK: %[[ADD:.*]] = spv.IAdd %14 = spv.IAdd %9, %8 : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -27,32 +27,32 @@ // ----- spv.module Logical GLSL450 requires #spv.vce { - spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> %2 = spv.Load "StorageBuffer" %1 : f32 - // CHECK: [[STORE_PTR:%.*]] 
= spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> spv.Store "StorageBuffer" %4, %2 : f32 spv.Return } - spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> %2 = spv.Load "StorageBuffer" %1 : i32 - // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> spv.Store "StorageBuffer" %4, %2 : i32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -1,17 +1,17 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - // CHECK: !spv.ptr [0]>, Input> - 
spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0]>, Input> + // CHECK: !spv.ptr [0]>, Input> + spv.globalVariable @var0 bind(0, 1) : !spv.ptr [0]>, Input> - // CHECK: !spv.ptr [4]> [4]>, Input> - spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4]> [4]>, Input> + // CHECK: !spv.ptr [4]> [4]>, Input> + spv.globalVariable @var1 bind(0, 2) : !spv.ptr [4]> [4]>, Input> // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @var2 : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr [0]> [4]> [0]>, StorageBuffer> - spv.globalVariable @var3 : !spv.ptr [0]> [4]> [0]>, StorageBuffer> + // CHECK: !spv.ptr [0]>, stride=512> [0]>, StorageBuffer> + spv.globalVariable @var3 : !spv.ptr [0]>, stride=512> [0]>, StorageBuffer> // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @var4 : !spv.ptr, StorageBuffer> @@ -25,9 +25,9 @@ // CHECK: !spv.ptr, StorageBuffer> spv.globalVariable @empty : !spv.ptr, StorageBuffer> - // CHECK: !spv.ptr [0]>, Input>, - // CHECK-SAME: !spv.ptr [0]>, Output> - spv.func @kernel_1(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () "None" { + // CHECK: !spv.ptr [0]>, Input>, + // CHECK-SAME: !spv.ptr [0]>, Output> + spv.func @kernel(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () "None" { spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -10,7 +10,7 @@ // CHECK-LABEL: spv.module spv.module Logical GLSL450 { // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> // CHECK: spv.func [[FN:@.*]]() spv.func @kernel( %arg0: f32 diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir 
b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -17,9 +17,9 @@ spv.globalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr, Input> - // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr [16]> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [16]> [0]>, StorageBuffer> - // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr [16]> [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, stride=16> [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr, stride=16> [0]>, StorageBuffer> + // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr, stride=16> [0]>, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir @@ -4,19 +4,19 @@ // CHECK: spv.globalVariable @var0 bind(0, 1) : !spv.ptr [4], f32 [12]>, Uniform> spv.globalVariable @var0 bind(0,1) : !spv.ptr, f32>, Uniform> - // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr [0], f32 [256]>, StorageBuffer> + // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr [0], f32 [256]>, StorageBuffer> spv.globalVariable @var1 bind(0,2) : !spv.ptr, f32>, StorageBuffer> - // 
CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr [0], f32 [256]> [0], i32 [260]>, StorageBuffer> + // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr [0], f32 [256]> [0], i32 [260]>, StorageBuffer> spv.globalVariable @var2 bind(1,0) : !spv.ptr, f32>, i32>, StorageBuffer> - // CHECK: spv.globalVariable @var3 : !spv.ptr [8]> [72]> [0], f32 [1152]>, StorageBuffer> + // CHECK: spv.globalVariable @var3 : !spv.ptr [8]>, stride=72> [0], f32 [1152]>, StorageBuffer> spv.globalVariable @var3 : !spv.ptr>>, f32>, StorageBuffer> // CHECK: spv.globalVariable @var4 bind(1, 2) : !spv.ptr [0], f32 [16], i1 [20]> [0], i1 [24]>, StorageBuffer> spv.globalVariable @var4 bind(1,2) : !spv.ptr, f32, i1>, i1>, StorageBuffer> - // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0]>, StorageBuffer> + // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0]>, StorageBuffer> spv.globalVariable @var5 bind(1,3) : !spv.ptr>, StorageBuffer> spv.func @kernel() -> () "None" { @@ -38,10 +38,10 @@ // CHECK: spv.globalVariable @var1 : !spv.ptr [8], f32 [24]> [0], f32 [32]>, Uniform> spv.globalVariable @var1 : !spv.ptr, f32>, f32>, Uniform> - // CHECK: spv.globalVariable @var2 : !spv.ptr [128]> [8]> [8], f32 [2064]> [0], f32 [2072]>, Uniform> + // CHECK: spv.globalVariable @var2 : !spv.ptr, stride=128> [8]> [8], f32 [2064]> [0], f32 [2072]>, Uniform> spv.globalVariable @var2 : !spv.ptr>>, f32>, f32>, Uniform> - // CHECK: spv.globalVariable @var3 : !spv.ptr [0], i1 [512]> [0], i1 [520]>, Uniform> + // CHECK: spv.globalVariable @var3 : !spv.ptr [0], i1 [512]> [0], i1 [520]>, Uniform> spv.globalVariable @var3 : !spv.ptr, i1>, i1>, Uniform> // CHECK: spv.globalVariable @var4 : !spv.ptr [8], i1 [24]>, Uniform> diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -58,20 +58,20 @@ // CHECK: %2 = spv.constant 5.000000e-01 : f32 // CHECK: %3 = 
spv.constant dense<[2, 3]> : vector<2xi32> // CHECK: %4 = spv.constant [dense<3.000000e+00> : vector<2xf32>] : !spv.array<1 x vector<2xf32>> - // CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> - // CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> - // CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> - // CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> + // CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>> + // CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>> + // CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>> + // CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>> %0 = spv.constant true %1 = spv.constant 42 : i32 %2 = spv.constant 0.5 : f32 %3 = spv.constant dense<[2, 3]> : vector<2xi32> %4 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>> - %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> - %6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> - %7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> - %8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> + %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>> + %6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x 
f32>> + %7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>> + %8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>> return } @@ -118,14 +118,14 @@ func @value_result_type_mismatch() -> () { // expected-error @+1 {{result element type ('i32') does not match value element type ('f32')}} - %0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> + %0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32>> } // ----- func @value_result_num_elements_mismatch() -> () { // expected-error @+1 {{result number of elements (6) does not match value number of elements (4)}} - %0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> + %0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32>> return } diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -12,8 +12,8 @@ // CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>) func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> () -// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>) -func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> () +// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32, stride=4>, stride=128>) +func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32, stride=4>, stride = 128>) -> () // ----- @@ -78,7 +78,7 @@ // ----- // expected-error @+1 {{ArrayStride must be greater than zero}} -func @array_type_zero_stride(!spv.array<4xi32 [0]>) -> () +func @array_type_zero_stride(!spv.array<4xi32, stride=0>) -> () // ----- @@ -132,6 +132,9 @@ // CHECK: func @vector_runtime_array_type(!spv.rtarray>) func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> () +// CHECK: func 
@runtime_array_type_stride(!spv.rtarray) +func @runtime_array_type_stride(!spv.rtarray) -> () + // ----- // expected-error @+1 {{expected '<'}} @@ -149,6 +152,11 @@ // ----- +// expected-error @+1 {{ArrayStride must be greater than zero}} +func @runtime_array_type_zero_stride(!spv.rtarray) -> () + +// ----- + //===----------------------------------------------------------------------===// // ImageType //===----------------------------------------------------------------------===//