[mlir][spirv] Add an optional symbol name to spv.module

Allow spv.module to carry an optional symbol name, printed and parsed as
`spv.module @name <addressing-model> <memory-model> ...`.

The name is stored under SymbolTable::getSymbolAttrName() ("sym_name") by
the builders and the parser, and read back from there by getName() and the
printer. The ODS argument is therefore declared as `$sym_name`: declaring it
as `$name` would make ODS generate accessors and verification for an
attribute literally called "name" that nothing in this patch ever sets.

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -409,7 +409,8 @@
   let arguments = (ins
     SPV_AddressingModelAttr:$addressing_model,
     SPV_MemoryModelAttr:$memory_model,
-    OptionalAttr<SPV_VerCapExtAttr>:$vce_triple
+    OptionalAttr<SPV_VerCapExtAttr>:$vce_triple,
+    OptionalAttr<StrAttr>:$sym_name
   );
 
   let results = (outs);
@@ -417,10 +418,12 @@
   let regions = (region SizedRegion<1>:$body);
 
   let builders = [
-    OpBuilder<[{OpBuilder &, OperationState &state}]>,
+    OpBuilder<[{OpBuilder &, OperationState &state,
+                Optional<StringRef> name = llvm::None}]>,
     OpBuilder<[{OpBuilder &, OperationState &state,
                 spirv::AddressingModel addressing_model,
-                spirv::MemoryModel memory_model}]>
+                spirv::MemoryModel memory_model,
+                Optional<StringRef> name = llvm::None}]>
   ];
 
   // We need to ensure the block inside the region is properly terminated;
@@ -432,6 +435,13 @@
   let autogenSerialization = 0;
 
   let extraClassDeclaration = [{
+    Optional<StringRef> getName() {
+      if (auto nameAttr =
+              getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
+        return nameAttr.getValue();
+      return llvm::None;
+    }
+
     static StringRef getVCETripleAttrName() { return "vce_triple"; }
 
     Block& getBlock() {
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2282,24 +2282,37 @@
 // spv.module
 //===----------------------------------------------------------------------===//
 
-void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state) {
+void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
+                            Optional<StringRef> name) {
   ensureTerminator(*state.addRegion(), builder, state.location);
+  if (name)
+    state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
 }
 
 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                             spirv::AddressingModel addressing_model,
-                            spirv::MemoryModel memory_model) {
+                            spirv::MemoryModel memory_model,
+                            Optional<StringRef> name) {
   state.addAttribute(
       "addressing_model",
       builder.getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
   state.addAttribute("memory_model", builder.getI32IntegerAttr(
                                          static_cast<int32_t>(memory_model)));
   ensureTerminator(*state.addRegion(), builder, state.location);
+  if (name)
+    state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
 }
 
 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
   Region *body = state.addRegion();
 
+  // If the name is present, parse it; a missing name is not an error.
+  StringAttr nameAttr;
+  (void)parser.parseOptionalSymbolName(
+      nameAttr, mlir::SymbolTable::getSymbolAttrName(), state.attributes);
+
   // Parse attributes
   spirv::AddressingModel addrModel;
   spirv::MemoryModel memoryModel;
@@ -2328,13 +2341,19 @@
 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
   printer << spirv::ModuleOp::getOperationName();
 
+  if (Optional<StringRef> name = moduleOp.getName()) {
+    printer << ' ';
+    printer.printSymbolName(*name);
+  }
+
   SmallVector<StringRef, 2> elidedAttrs;
 
   printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
           << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
-  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
+  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
+                      SymbolTable::getSymbolAttrName()});
 
   if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
     printer << " requires " << *triple;
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -372,6 +372,9 @@
 // CHECK: spv.module Logical GLSL450
 spv.module Logical GLSL450 { }
 
+// Module with a name
+// CHECK: spv.module @name Logical GLSL450
+spv.module @name Logical GLSL450 { }
 // Module with (version, capabilities, extensions) triple
 // CHECK: spv.module Logical GLSL450 requires #spv.vce