[mlir][spirv] Add an optional symbol name to spv.module

Make spv.module a Symbol that carries an optional `sym_name` attribute.
Thread an optional name through both custom builders, and parse/print it as
`spv.module @name <addressing-model> <memory-model> ...`. The round-trip
test pins the exact printed name.

NOTE(review): this patch was reconstructed from a copy that had been
collapsed onto a few lines and had every `<...>` span stripped (template
arguments and the `#spv.vce<...>` literal in the test context). Those spans
and the context-line whitespace were restored by hand. Confirm them against
the tree before landing (or apply with `git apply --ignore-whitespace`).

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -361,7 +361,7 @@
 def SPV_ModuleOp : SPV_Op<"module",
                           [IsolatedFromAbove,
                            SingleBlockImplicitTerminator<"ModuleEndOp">,
-                           SymbolTable]> {
+                           SymbolTable, Symbol]> {
   let summary = "The top-level op that defines a SPIR-V module";
 
   let description = [{
@@ -409,7 +409,8 @@
   let arguments = (ins
     SPV_AddressingModelAttr:$addressing_model,
     SPV_MemoryModelAttr:$memory_model,
-    OptionalAttr<SPV_VerCapExtAttr>:$vce_triple
+    OptionalAttr<SPV_VerCapExtAttr>:$vce_triple,
+    OptionalAttr<StrAttr>:$sym_name
   );
 
   let results = (outs);
@@ -417,10 +418,12 @@
   let regions = (region SizedRegion<1>:$body);
 
   let builders = [
-    OpBuilder<[{OpBuilder &, OperationState &state}]>,
+    OpBuilder<[{OpBuilder &, OperationState &state,
+                Optional<StringRef> name = llvm::None}]>,
     OpBuilder<[{OpBuilder &, OperationState &state,
-                spirv::AddressingModel addressing_model,
-                spirv::MemoryModel memory_model}]>
+                spirv::AddressingModel addressingModel,
+                spirv::MemoryModel memoryModel,
+                Optional<StringRef> name = llvm::None}]>
   ];
 
   // We need to ensure the block inside the region is properly terminated;
@@ -432,6 +435,12 @@
   let autogenSerialization = 0;
 
   let extraClassDeclaration = [{
+    // The symbol name is optional: a spv.module may be left unnamed.
+    bool isOptionalSymbol() { return true; }
+
+    // Returns the `sym_name` attribute value, or None for an unnamed module.
+    Optional<StringRef> getName() { return sym_name(); }
+
     static StringRef getVCETripleAttrName() { return "vce_triple"; }
 
     Block& getBlock() {
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2282,24 +2282,39 @@
 // spv.module
 //===----------------------------------------------------------------------===//
 
-void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state) {
+void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
+                            Optional<StringRef> name) {
   ensureTerminator(*state.addRegion(), builder, state.location);
+  if (name) {
+    state.addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
+  }
 }
 
 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
-                            spirv::AddressingModel addressing_model,
-                            spirv::MemoryModel memory_model) {
+                            spirv::AddressingModel addressingModel,
+                            spirv::MemoryModel memoryModel,
+                            Optional<StringRef> name) {
   state.addAttribute(
       "addressing_model",
-      builder.getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+      builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
   state.addAttribute("memory_model", builder.getI32IntegerAttr(
-                                         static_cast<int32_t>(memory_model)));
+                                         static_cast<int32_t>(memoryModel)));
   ensureTerminator(*state.addRegion(), builder, state.location);
+  if (name) {
+    state.addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
+  }
 }
 
 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
   Region *body = state.addRegion();
 
+  // Parse the optional `@name`; a failed parse only means it was omitted.
+  StringAttr nameAttr;
+  parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                                 state.attributes);
+
   // Parse attributes
   spirv::AddressingModel addrModel;
   spirv::MemoryModel memoryModel;
@@ -2328,13 +2343,19 @@
 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
   printer << spirv::ModuleOp::getOperationName();
 
+  if (Optional<StringRef> name = moduleOp.getName()) {
+    printer << ' ';
+    printer.printSymbolName(*name);
+  }
+
   SmallVector<StringRef, 2> elidedAttrs;
 
   printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
           << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
-  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
+  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
+                      SymbolTable::getSymbolAttrName()});
 
   if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
     printer << " requires " << *triple;
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -372,6 +372,10 @@
 
 // CHECK: spv.module Logical GLSL450
 spv.module Logical GLSL450 { }
+
+// Module with a name
+// CHECK: spv.module @name Logical GLSL450
+spv.module @name Logical GLSL450 { }
 
 // Module with (version, capabilities, extensions) triple
 // CHECK: spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>