diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -880,14 +880,30 @@ * `spv.entry_point_abi` is a struct attribute that should be attached to the entry function. It contains: * `local_size` for specifying the local work group size for the dispatch. -* `spv.interface_var_abi` is a struct attribute that should be attached to - each operand and result of the entry function. It contains: - * `descriptor_set` for specifying the descriptor set number for the - corresponding resource variable. - * `binding` for specifying the binding number for the corresponding - resource variable. - * `storage_class` for specifying the storage class for the corresponding - resource variable. +* `spv.interface_var_abi` is an attribute that should be attached to each operand + and result of the entry function. It should be of `#spv.interface_var_abi` + attribute kind, which is defined as: + +``` +spv-storage-class ::= `StorageBuffer` | ... +spv-descriptor-set ::= integer +spv-binding ::= integer +spv-interface-var-abi ::= `#` `spv.interface_var_abi` `<(` spv-descriptor-set + `,` spv-binding `)` (`,` spv-storage-class)? `>` +``` + +For example, + +``` +#spv.interface_var_abi<(0, 0), StorageBuffer> +#spv.interface_var_abi<(0, 1)> +``` + +The attribute has the following fields: + +* Descriptor set number for the corresponding resource variable. +* Binding number for the corresponding resource variable. +* Optional storage class for the corresponding resource variable; it may only be specified when the operand/result is a scalar value.
The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for consuming these attributes and create SPIR-V module complying with the diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H #define MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/Support/LLVM.h" @@ -26,6 +27,7 @@ enum class Version : uint32_t; namespace detail { +struct InterfaceVarABIAttributeStorage; struct TargetEnvAttributeStorage; struct VerCapExtAttributeStorage; } // namespace detail @@ -33,11 +35,49 @@ /// SPIR-V dialect-specific attribute kinds. namespace AttrKind { enum Kind { - TargetEnv = Attribute::FIRST_SPIRV_ATTR, /// Target environment + InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI + TargetEnv, /// Target environment VerCapExt, /// (version, extension, capability) triple }; } // namespace AttrKind +/// An attribute that specifies the information regarding the global variable: +/// descriptor set, binding, storage class. +class InterfaceVarABIAttr + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Gets an InterfaceVarABIAttr; the storage class is optional. + static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, + Optional storageClass, + MLIRContext *context); + static InterfaceVarABIAttr get(IntegerAttr descriptorSet, IntegerAttr binding, + IntegerAttr storageClass); + + /// Returns the attribute kind's name (without the 'spv.' prefix). + static StringRef getKindName(); + + /// Returns the descriptor set number. + uint32_t getDescriptorSet(); + + /// Returns the binding number. + uint32_t getBinding(); + + /// Returns the `spirv::StorageClass`, or None if it was not specified.
+ Optional getStorageClass(); + + static bool kindof(unsigned kind) { + return kind == AttrKind::InterfaceVarABI; + } + + static LogicalResult verifyConstructionInvariants(Location loc, + IntegerAttr descriptorSet, + IntegerAttr binding, + IntegerAttr storageClass); +}; + /// An attribute that specifies the SPIR-V (version, capabilities, extensions) /// triple. class VerCapExtAttr diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td @@ -23,18 +23,6 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td" -// For arguments that eventually map to spv.globalVariable for the -// shader interface, this attribute specifies the information regarding -// the global variable: -// 1) Descriptor Set. -// 2) Binding number. -// 3) Storage class. -def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [ - StructFieldAttr<"descriptor_set", I32Attr>, - StructFieldAttr<"binding", I32Attr>, - StructFieldAttr<"storage_class", OptionalAttr> -]>; - // For entry functions, this attribute specifies information related to entry // points in the generated SPIR-V module: // 1) WorkGroup Size.
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp @@ -25,6 +25,32 @@ namespace spirv { namespace detail { + +struct InterfaceVarABIAttributeStorage : public AttributeStorage { + using KeyTy = std::tuple; + + InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding, + Attribute storageClass) + : descriptorSet(descriptorSet), binding(binding), + storageClass(storageClass) {} + + bool operator==(const KeyTy &key) const { + return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding && + std::get<2>(key) == storageClass; + } + + static InterfaceVarABIAttributeStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key), + std::get<2>(key)); + } + + Attribute descriptorSet; + Attribute binding; + Attribute storageClass; +}; + struct VerCapExtAttributeStorage : public AttributeStorage { using KeyTy = std::tuple; @@ -72,6 +98,74 @@ } // namespace spirv } // namespace mlir +//===----------------------------------------------------------------------===// +// InterfaceVarABIAttr +//===----------------------------------------------------------------------===// + +spirv::InterfaceVarABIAttr +spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding, + Optional storageClass, + MLIRContext *context) { + Builder b(context); + auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet); + auto bindingAttr = b.getI32IntegerAttr(binding); + auto storageClassAttr = + storageClass ?
b.getI32IntegerAttr(static_cast(*storageClass)) + : IntegerAttr(); + return get(descriptorSetAttr, bindingAttr, storageClassAttr); +} + +spirv::InterfaceVarABIAttr +spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding, + IntegerAttr storageClass) { + assert(descriptorSet && binding); + MLIRContext *context = descriptorSet.getContext(); + return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet, + binding, storageClass); +} + +StringRef spirv::InterfaceVarABIAttr::getKindName() { + return "interface_var_abi"; +} + +uint32_t spirv::InterfaceVarABIAttr::getBinding() { + return getImpl()->binding.cast().getInt(); +} + +uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() { + return getImpl()->descriptorSet.cast().getInt(); +} + +Optional spirv::InterfaceVarABIAttr::getStorageClass() { + if (getImpl()->storageClass) + return static_cast( + getImpl()->storageClass.cast().getValue().getZExtValue()); + return llvm::None; +} + +LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants( + Location loc, IntegerAttr descriptorSet, IntegerAttr binding, + IntegerAttr storageClass) { + if (!descriptorSet.getType().isSignlessInteger(32)) + return emitError(loc, "expected 32-bit integer for descriptor set"); + + if (!binding.getType().isSignlessInteger(32)) + return emitError(loc, "expected 32-bit integer for binding"); + + if (storageClass) { +    // Stored as the integer value of a spirv::StorageClass enumerant, so it +    // must be a 32-bit integer that maps back to a known enumerant. +    if (!storageClass.getType().isSignlessInteger(32)) +      return emitError(loc, "expected 32-bit integer for storage class"); +    auto storageClassValue = +        spirv::symbolizeStorageClass(storageClass.getInt()); +    if (!storageClassValue) +      return emitError(loc, "unknown storage class"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // VerCapExtAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -118,7 +118,7 @@ : Dialect(getDialectNamespace(), context) { addTypes(); - addAttributes(); + addAttributes(); // Add SPIR-V ops. addOperations< @@ -628,6 +628,75 @@ return success(); } +/// Parses `<(set, binding)[, storage-class]>` into an InterfaceVarABIAttr. +static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) { + if (parser.parseLess()) + return {}; + + Builder &builder = parser.getBuilder(); + + if (parser.parseLParen()) + return {}; + + IntegerAttr descriptorSetAttr; + { + auto loc = parser.getCurrentLocation(); + uint32_t descriptorSet = 0; + auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet); + + if (!descriptorSetParseResult.hasValue() || + failed(*descriptorSetParseResult)) { + parser.emitError(loc, "missing descriptor set"); + return {}; + } + descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet); + } + + if (parser.parseComma()) + return {}; + + IntegerAttr bindingAttr; + { + auto loc = parser.getCurrentLocation(); + uint32_t binding = 0; + auto bindingParseResult = parser.parseOptionalInteger(binding); + + if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) { + parser.emitError(loc, "missing binding"); + return {}; + } + bindingAttr = builder.getI32IntegerAttr(binding); + } + + if (parser.parseRParen()) + return {}; + + IntegerAttr storageClassAttr; // Stays null if no storage class is given. + { + if (succeeded(parser.parseOptionalComma())) { + auto loc = parser.getCurrentLocation(); + StringRef storageClass; + if (parser.parseKeyword(&storageClass)) + return {}; + + if (auto storageClassSymbol = + spirv::symbolizeStorageClass(storageClass)) { + storageClassAttr = builder.getI32IntegerAttr( + static_cast(*storageClassSymbol)); + } else { + parser.emitError(loc, "unknown storage class: ") << storageClass; + return {}; + } + } + } + + if (parser.parseGreater()) + return {}; + + return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr, + storageClassAttr); +} +
static Attribute parseVerCapExtAttr(DialectAsmParser &parser) { if (parser.parseLess()) return {}; @@ -750,6 +819,8 @@ return parseTargetEnvAttr(parser); if (attrKind == spirv::VerCapExtAttr::getKindName()) return parseVerCapExtAttr(parser); + if (attrKind == spirv::InterfaceVarABIAttr::getKindName()) + return parseInterfaceVarABIAttr(parser); parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ") << attrKind; @@ -780,12 +851,25 @@ printer << ", " << targetEnv.getResourceLimits() << ">"; } +static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr, + DialectAsmPrinter &printer) { + printer << spirv::InterfaceVarABIAttr::getKindName() << "<(" + << interfaceVarABIAttr.getDescriptorSet() << ", " + << interfaceVarABIAttr.getBinding() << ")"; + auto storageClass = interfaceVarABIAttr.getStorageClass(); + if (storageClass) // The storage class is only printed when present. + printer << ", " << spirv::stringifyStorageClass(*storageClass); + printer << ">"; +} + void SPIRVDialect::printAttribute(Attribute attr, DialectAsmPrinter &printer) const { if (auto targetEnv = attr.dyn_cast()) print(targetEnv, printer); else if (auto vceAttr = attr.dyn_cast()) print(vceAttr, printer); + else if (auto interfaceVarABIAttr = attr.dyn_cast()) + print(interfaceVarABIAttr, printer); else llvm_unreachable("unhandled SPIR-V attribute kind"); } @@ -845,11 +929,9 @@ auto varABIAttr = attr.dyn_cast(); if (!varABIAttr) return emitError(loc, "'") - << symbol - << "' attribute must be a dictionary attribute containing two or " - "three 32-bit integer attributes: 'descriptor_set', 'binding', " - "and optional 'storage_class'"; - if (varABIAttr.storage_class() && !valueType.isIntOrIndexOrFloat()) + << symbol << "' must be a spirv::InterfaceVarABIAttr"; + + if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) return emitError(loc, "'") << symbol << "' attribute cannot specify storage class " "when attaching to a non-scalar value"; diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -86,14 +86,8 @@ spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, Optional storageClass, MLIRContext *context) { - Type i32Type = IntegerType::get(32, context); - auto scAttr = - storageClass - ? IntegerAttr::get(i32Type, static_cast(*storageClass)) - : IntegerAttr(); - return spirv::InterfaceVarABIAttr::get( - IntegerAttr::get(i32Type, descriptorSet), - IntegerAttr::get(i32Type, binding), scAttr, context); + return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, + context); } StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -41,10 +41,11 @@ // it must already be a !spv.ptr>.
auto varType = funcOp.getType().getInput(argIndex); if (varType.cast().isScalarOrVector()) { - auto storageClass = - static_cast(abiInfo.storage_class().getInt()); + auto storageClass = abiInfo.getStorageClass(); + if (!storageClass) + return nullptr; // Scalar/vector args must carry a storage class. varType = - spirv::PointerType::get(spirv::StructType::get(varType), storageClass); + spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); } auto varPtrType = varType.cast(); auto varPointeeType = varPtrType.getPointeeType().cast(); @@ -58,8 +59,8 @@ spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass()); return builder.create( - funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(), - abiInfo.binding().getInt()); + funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); } /// Gets the global variables that need to be specified as interface variable diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -27,13 +27,13 @@ // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr, Input> // CHECK-LABEL: spv.func @load_store_kernel - // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}} - // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} - // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}} - // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: [[ARG4:%.*]]: i32
{spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} + // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>} + // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} + // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>} + // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>} + // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>} + // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>} + // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>} gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { // CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]] diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -4,8 +4,8 @@ gpu.module @kernels { // CHECK: spv.module Logical GLSL450 { // CHECK-LABEL: spv.func @basic_module_structure - // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: {{%.*}}: !spv.ptr
[0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>} + // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} { diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -14,12 +14,9 @@ // CHECK: spv.func [[FN:@.*]]() spv.func @kernel( %arg0: f32 - {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}, %arg1: !spv.ptr>, StorageBuffer> - {spv.interface_var_abi = {binding = 1 : i32, - descriptor_set = 0 : i32}}) "None" + {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]] diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -27,30 +27,19 @@ // CHECK: spv.func [[FN:@.*]]() spv.func @load_store_kernel( %arg0: !spv.ptr>>, StorageBuffer> - {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}, %arg1: !spv.ptr>>, StorageBuffer> -
{spv.interface_var_abi = {binding = 1 : i32, - descriptor_set = 0 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}, %arg2: !spv.ptr>>, StorageBuffer> - {spv.interface_var_abi = {binding = 2 : i32, - descriptor_set = 0 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}, %arg3: i32 - {spv.interface_var_abi = {binding = 3 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}, %arg4: i32 - {spv.interface_var_abi = {binding = 4 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}, %arg5: i32 - {spv.interface_var_abi = {binding = 5 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}, %arg6: i32 - {spv.interface_var_abi = {binding = 6 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}) "None" + {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]] // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32 diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir @@ -51,34 +51,51 @@ // spv.interface_var_abi //===----------------------------------------------------------------------===// -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' must be a spirv::InterfaceVarABIAttr}} func @interface_var( %arg0 : f32 {spv.interface_var_abi = 64} ) { return } // ----- -// expected-error @+1
{{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}} func @interface_var( - %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}} +// expected-error @+1 {{missing descriptor set}} + %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<()>} ) { return } // ----- -// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}} func @interface_var( - %arg0 : f32 {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}} +// expected-error @+1 {{missing binding}} + %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1,)>} +) { return } + +// ----- + +func @interface_var( +// expected-error @+1 {{unknown storage class: Foo}} + %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1,2), Foo>} +) { return } + +// ----- + +// CHECK: {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>} +func @interface_var( + %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>} +) { return } + +// ----- + +// CHECK: {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} +func @interface_var( + %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} ) { return } // ----- // expected-error @+1 {{'spv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}} func @interface_var( - %arg0 : memref<4xf32> {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}} + %arg0 : memref<4xf32> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>} ) { return } // -----