diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -31,6 +31,15 @@ /// Gets the SPIR-V correspondence for the standard index type. static Type getIndexType(MLIRContext *context); + + /// Returns the corresponding memory space for memref given a SPIR-V storage + /// class. + static unsigned getMemorySpaceForStorageClass(spirv::StorageClass); + + /// Returns the SPIR-V storage class given a memory space for memref. Returns + /// llvm::None if the memory space does not map to any SPIR-V storage class. + static Optional + getStorageClassForMemorySpace(unsigned space); }; /// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -28,7 +28,7 @@ /// Gets the InterfaceVarABIAttr given its fields. InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, - StorageClass storageClass, + Optional storageClass, MLIRContext *context); /// Returns the attribute name for specifying entry point information. 
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td @@ -32,7 +32,7 @@ def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [ StructFieldAttr<"descriptor_set", I32Attr>, StructFieldAttr<"binding", I32Attr>, - StructFieldAttr<"storage_class", SPV_StorageClassAttr> + StructFieldAttr<"storage_class", OptionalAttr> ]>; // For entry functions, this attribute specifies information related to entry diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -169,8 +169,11 @@ /// Return true of this is a signless integer or a float type. bool isSignlessIntOrFloat(); - /// Return true of this is an integer(of any signedness) or a float type. + /// Return true if this is an integer (of any signedness) or a float type. bool isIntOrFloat(); + /// Return true if this is an integer (of any signedness), index, or float + /// type. + bool isIntOrIndexOrFloat(); /// Print the current type. void print(raw_ostream &os); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -349,10 +349,15 @@ if (!gpu::GPUDialect::isKernel(funcOp)) return failure(); + // TODO(antiagainst): we are dictating the ABI by ourselves here; it should be + // specified outside. 
SmallVector argABI; - for (auto argNum : llvm::seq(0, funcOp.getNumArguments())) { - argABI.push_back(spirv::getInterfaceVarABIAttr( - 0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext())); + for (auto argIndex : llvm::seq(0, funcOp.getNumArguments())) { + Optional sc; + if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat()) + sc = spirv::StorageClass::StorageBuffer; + argABI.push_back( + spirv::getInterfaceVarABIAttr(0, argIndex, sc, rewriter.getContext())); } auto entryPointAttr = spirv::lookupEntryPointABI(funcOp); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -61,7 +61,7 @@ BlockAndValueMapping &) const final { // Return true here when inlining into spv.func, spv.selection, and // spv.loop operations. - auto op = dest->getParentOp(); + auto *op = dest->getParentOp(); return isa(op) || isa(op) || isa(op); } @@ -383,7 +383,8 @@ // parseAndVerify does the actual parsing and verification of individual // elements. This is a functor since parsing the last element of the list // (termination condition) needs partial specialization. -template struct parseCommaSeparatedList { +template +struct ParseCommaSeparatedList { Optional> operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { auto parseVal = parseAndVerify(dialect, parser); @@ -393,7 +394,7 @@ auto numArgs = std::tuple_size>::value; if (numArgs != 0 && failed(parser.parseComma())) return llvm::None; - auto remainingValues = parseCommaSeparatedList{}(dialect, parser); + auto remainingValues = ParseCommaSeparatedList{}(dialect, parser); if (!remainingValues) return llvm::None; return std::tuple_cat(std::tuple(parseVal.getValue()), @@ -403,7 +404,8 @@ // Partial specialization of the function to parse a comma separated list of // specs to parse the last element of the list. 
-template struct parseCommaSeparatedList { +template +struct ParseCommaSeparatedList { Optional> operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { if (auto value = parseAndVerify(dialect, parser)) @@ -434,7 +436,7 @@ return Type(); auto value = - parseCommaSeparatedList{}(dialect, parser); if (!value) @@ -597,10 +599,10 @@ if (!decorations.empty()) os << ", "; } - auto each_fn = [&os](spirv::Decoration decoration) { + auto eachFn = [&os](spirv::Decoration decoration) { os << stringifyDecoration(decoration); }; - interleaveComma(decorations, os, each_fn); + interleaveComma(decorations, os, eachFn); os << "]"; } }; @@ -865,39 +867,46 @@ return success(); } -// Verifies the given SPIR-V `attribute` attached to a region's argument or -// result and reports error to the given location if invalid. -static LogicalResult -verifyRegionAttribute(Location loc, NamedAttribute attribute, bool forArg) { +/// Verifies the given SPIR-V `attribute` attached to a value of the given +/// `valueType` (null if unknown) is valid. +static LogicalResult verifyRegionAttribute(Location loc, Type valueType, + NamedAttribute attribute) { StringRef symbol = attribute.first.strref(); Attribute attr = attribute.second; if (symbol != spirv::getInterfaceVarABIAttrName()) return emitError(loc, "found unsupported '") - << symbol << "' attribute on region " - << (forArg ? 
"argument" : "result"); + << symbol << "' attribute on region argument"; - if (!attr.isa()) + auto varABIAttr = attr.dyn_cast(); + if (!varABIAttr) return emitError(loc, "'") << symbol - << "' attribute must be a dictionary attribute containing three " - "32-bit integer attributes: 'descriptor_set', 'binding', and " - "'storage_class'"; + << "' attribute must be a dictionary attribute containing two or " + "three 32-bit integer attributes: 'descriptor_set', 'binding', " + "and optional 'storage_class'"; + if (varABIAttr.storage_class() && valueType && + !valueType.isIntOrIndexOrFloat()) + return emitError(loc, "'") << symbol + << "' attribute cannot specify storage class " + "when attaching to a non-scalar value"; return success(); } LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, - unsigned /*regionIndex*/, - unsigned /*argIndex*/, + unsigned regionIndex, + unsigned argIndex, NamedAttribute attribute) { - return verifyRegionAttribute(op->getLoc(), attribute, - /*forArg=*/true); + Region &region = op->getRegion(regionIndex); + // Declarations have an empty region, so the argument type is unknown. + Type argType = + region.empty() ? Type() : region.front().getArgument(argIndex).getType(); + return verifyRegionAttribute(op->getLoc(), argType, attribute); } LogicalResult SPIRVDialect::verifyRegionResultAttribute( Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, NamedAttribute attribute) { - return verifyRegionAttribute(op->getLoc(), attribute, - /*forArg=*/false); + return op->emitError("cannot attach SPIR-V attributes to region result"); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -38,6 +38,64 @@ return IntegerType::get(32, context); } +/// Mapping from SPIR-V storage classes to memref memory spaces. +/// +/// Note: memref does not have defined semantics for each memory space; it +/// depends on the context where it is used. 
There are no particular reasons +/// behind the number assignments; we try to follow NVVM conventions and largely +/// give common storage classes a smaller number. The hope is to use a symbolic +/// memory space representation eventually, once memref supports it. +// TODO(antiagainst): swap Generic and StorageBuffer assignment to be more akin +// to NVVM. +#define STORAGE_SPACE_MAP_LIST(MAP_FN) \ + MAP_FN(spirv::StorageClass::Generic, 1) \ + MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ + MAP_FN(spirv::StorageClass::Workgroup, 3) \ + MAP_FN(spirv::StorageClass::Uniform, 4) \ + MAP_FN(spirv::StorageClass::Private, 5) \ + MAP_FN(spirv::StorageClass::Function, 6) \ + MAP_FN(spirv::StorageClass::PushConstant, 7) \ + MAP_FN(spirv::StorageClass::UniformConstant, 8) \ + MAP_FN(spirv::StorageClass::Input, 9) \ + MAP_FN(spirv::StorageClass::Output, 10) \ + MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ + MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ + MAP_FN(spirv::StorageClass::Image, 13) \ + MAP_FN(spirv::StorageClass::CallableDataNV, 14) \ + MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15) \ + MAP_FN(spirv::StorageClass::RayPayloadNV, 16) \ + MAP_FN(spirv::StorageClass::HitAttributeNV, 17) \ + MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18) \ + MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19) \ + MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) + +unsigned +SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case storage: \ + return space; + + switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) } +#undef STORAGE_SPACE_MAP_FN + llvm_unreachable("unhandled storage class!"); +} + +Optional +SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case space: \ + return storage; + + switch (space) { + STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + return llvm::None; + } +#undef STORAGE_SPACE_MAP_FN +} + +#undef STORAGE_SPACE_MAP_LIST + 
// TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. static Optional getTypeNumBytes(Type t) { @@ -110,14 +168,6 @@ return SPIRVTypeConverter::getIndexType(indexType.getContext()); }); addConversion([this](MemRefType memRefType) -> Type { - // TODO(ravishankarm): For now only support default memory space. The memory - // space description is not set is stone within MLIR, i.e. it depends on the - // context it is being used. To map this to SPIR-V storage classes, we - // should rely on the ABI attributes, and not on the memory space. This is - // still evolving, and needs to be revisited when there is more clarity. - if (memRefType.getMemorySpace()) - return Type(); - auto elementType = convertType(memRefType.getElementType()); if (!elementType) return Type(); @@ -135,11 +185,12 @@ auto arrayType = spirv::ArrayType::get( elementType, arraySize.getValue() / elementSize.getValue(), elementSize.getValue()); + + // Wrap in a struct to satisfy Vulkan interface requirements. auto structType = spirv::StructType::get(arrayType, 0); - // For now initialize the storage class to StorageBuffer. This will be - // updated later based on whats passed in w.r.t to the ABI attributes. 
- return spirv::PointerType::get(structType, - spirv::StorageClass::StorageBuffer); + if (auto sc = getStorageClassForMemorySpace(memRefType.getMemorySpace())) + return spirv::PointerType::get(structType, *sc); + return Type(); } return Type(); }); diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -21,13 +21,16 @@ spirv::InterfaceVarABIAttr spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, - spirv::StorageClass storageClass, + Optional storageClass, MLIRContext *context) { Type i32Type = IntegerType::get(32, context); + auto scAttr = + storageClass + ? IntegerAttr::get(i32Type, static_cast(*storageClass)) + : IntegerAttr(); return spirv::InterfaceVarABIAttr::get( IntegerAttr::get(i32Type, descriptorSet), - IntegerAttr::get(i32Type, binding), - IntegerAttr::get(i32Type, static_cast(storageClass)), context); + IntegerAttr::get(i32Type, binding), scAttr, context); } StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -29,25 +29,28 @@ /// Creates a global variable for an argument based on the ABI info. 
static spirv::GlobalVariableOp -createGlobalVariableForArg(spirv::FuncOp funcOp, OpBuilder &builder, - unsigned argNum, - spirv::InterfaceVarABIAttr abiInfo) { +createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, + unsigned argIndex, + spirv::InterfaceVarABIAttr abiInfo) { auto spirvModule = funcOp.getParentOfType(); - if (!spirvModule) { + if (!spirvModule) return nullptr; - } + OpBuilder::InsertionGuard moduleInsertionGuard(builder); builder.setInsertionPoint(funcOp.getOperation()); std::string varName = - funcOp.getName().str() + "_arg_" + std::to_string(argNum); + funcOp.getName().str() + "_arg_" + std::to_string(argIndex); // Get the type of variable. If this is a scalar/vector type and has an ABI - // info create a variable of type !spv.ptr>. If not + // info create a variable of type !spv.ptr>. If not // it must already be a !spv.ptr>. - auto varType = funcOp.getType().getInput(argNum); - auto storageClass = - static_cast(abiInfo.storage_class().getInt()); + auto varType = funcOp.getType().getInput(argIndex); if (isScalarOrVectorType(varType)) { + // Scalar/vector arguments need an explicit storage class in the ABI info. + if (!abiInfo.storage_class()) + return nullptr; + auto storageClass = + static_cast(abiInfo.storage_class().getInt()); varType = spirv::PointerType::get(spirv::StructType::get(varType), storageClass); } @@ -84,9 +87,18 @@ funcOp.walk([&](spirv::AddressOfOp addressOfOp) { auto var = module.lookupSymbol(addressOfOp.variable()); - if (var.type().cast().getStorageClass() != - spirv::StorageClass::StorageBuffer) { + // TODO(antiagainst): Per SPIR-V spec: "Before version 1.4, the interface’s + // storage classes are limited to the Input and Output storage classes. + // Starting with version 1.4, the interface’s storage classes are all + // storage classes used in declaring all global variables referenced by the + // entry point’s call tree." We should consider the target environment here. 
+ switch (var.type().cast().getStorageClass()) { + case spirv::StorageClass::Input: + case spirv::StorageClass::Output: interfaceVarSet.insert(var.getOperation()); + break; + default: + break; } }); for (auto &var : interfaceVarSet) { @@ -173,11 +185,10 @@ // produce an error. return failure(); } - auto var = - createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo); - if (!var) { + spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument( + rewriter, funcOp, argType.index(), abiInfo); + if (!var) return failure(); - } OpBuilder::InsertionGuard funcInsertionGuard(rewriter); rewriter.setInsertionPointToStart(&funcOp.front()); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -86,6 +86,8 @@ bool Type::isIntOrFloat() { return isa() || isa(); } +bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); } + //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -21,9 +21,9 @@ // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr, Input> // CHECK-LABEL: spv.func @load_store_kernel - // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr 
[0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} + // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} + // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}} // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -5,7 +5,7 @@ // CHECK: spv.module Logical GLSL450 { // CHECK-LABEL: spv.func @basic_module_structure // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} - // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} + // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}} // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} { diff --git 
a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -313,6 +313,22 @@ return } +// CHECK-LABEL: func @memref_mem_space +// CHECK-SAME: StorageBuffer +// CHECK-SAME: Uniform +// CHECK-SAME: Workgroup +// CHECK-SAME: PushConstant +// CHECK-SAME: Private +// CHECK-SAME: Function +func @memref_mem_space( + %arg0: memref<4xf32, 0>, + %arg1: memref<4xf32, 4>, + %arg2: memref<4xf32, 3>, + %arg3: memref<4xf32, 7>, + %arg4: memref<4xf32, 5>, + %arg5: memref<4xf32, 6> +) { return } + // CHECK-LABEL: @load_store_zero_rank_float // CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, // CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir rename from mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir rename to mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -5,14 +5,14 @@ // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> // CHECK: spv.func [[FN:@.*]]() - spv.func @kernel(%arg0: f32 - {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg1: !spv.ptr>, StorageBuffer> - {spv.interface_var_abi = {binding = 1 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}) "None" + spv.func @kernel( + %arg0: f32 + {spv.interface_var_abi = {binding = 0 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg1: !spv.ptr>, StorageBuffer> + {spv.interface_var_abi = {binding = 1 : i32, + descriptor_set = 0 : i32}}) "None" attributes {spv.entry_point_abi = {local_size = 
dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]] diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -21,16 +21,13 @@ spv.func @load_store_kernel( %arg0: !spv.ptr>>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + descriptor_set = 0 : i32}}, %arg1: !spv.ptr>>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + descriptor_set = 0 : i32}}, %arg2: !spv.ptr>>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, + descriptor_set = 0 : i32}}, %arg3: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir @@ -14,7 +14,7 @@ // ----- -// expected-error @+1 {{found unsupported 'spv.something' attribute on region result}} +// expected-error @+1 {{cannot attach SPIR-V attributes to region result}} func @unknown_attr_on_region() -> (i32 {spv.something}) { %0 = constant 10.0 : f32 return %0: f32 @@ -51,14 +51,14 @@ // spv.interface_var_abi //===----------------------------------------------------------------------===// -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', 
and optional 'storage_class'}} func @interface_var( %arg0 : f32 {spv.interface_var_abi = 64} ) { return } // ----- -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}} func @interface_var( %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}} ) { return } @@ -74,31 +74,12 @@ // ----- -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} -func @interface_var() -> (f32 {spv.interface_var_abi = 64}) -{ - %0 = constant 10.0 : f32 - return %0: f32 -} - -// ----- - -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} -func @interface_var() -> (f32 {spv.interface_var_abi = {binding = 0: i32}}) -{ - %0 = constant 10.0 : f32 - return %0: f32 -} - -// ----- - -// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}} -func @interface_var() -> (f32 {spv.interface_var_abi = { - binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}) -{ - %0 = constant 10.0 : f32 - return %0: f32 -} +// expected-error @+1 {{'spv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}} +func @interface_var( + %arg0 : memref<4xf32> {spv.interface_var_abi = {binding = 0 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}} +) { return } // -----