diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -97,6 +97,10 @@
   let results = (outs
     SPV_IntVec4:$result
   );
+
+  let assemblyFormat = [{
+    $execution_scope $predicate attr-dict `:` type($result)
+  }];
 }
 
 // -----
@@ -145,6 +149,8 @@
   let builders = [
     OpBuilder<[{Builder *builder, OperationState &state, spirv::Scope}]>
   ];
+
+  let assemblyFormat = "$execution_scope attr-dict `:` type($result)";
 }
 
 // -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -162,6 +162,10 @@
   let verifier = [{ return verifyMemorySemantics(*this); }];
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    $execution_scope `,` $memory_scope `,` $memory_semantics attr-dict
+  }];
 }
 
 // -----
@@ -319,6 +323,8 @@
   let verifier = [{ return verifyMemorySemantics(*this); }];
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict";
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1468,32 +1468,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// spv.ControlBarrier
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseControlBarrierOp(OpAsmParser &parser,
-                                         OperationState &state) {
-  spirv::Scope executionScope;
-  spirv::Scope memoryScope;
-  spirv::MemorySemantics memorySemantics;
-
-  return failure(
-      parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parser.parseComma() ||
-      parseEnumAttribute(memoryScope, parser, state, kMemoryScopeAttrName) ||
-      parser.parseComma() ||
-      parseEnumAttribute(memorySemantics, parser, state));
-}
-
-static void print(spirv::ControlBarrierOp op, OpAsmPrinter &printer) {
-  printer << spirv::ControlBarrierOp::getOperationName() << " \""
-          << stringifyScope(op.execution_scope()) << "\", \""
-          << stringifyScope(op.memory_scope()) << "\", \""
-          << stringifyMemorySemantics(op.memory_semantics()) << "\"";
-}
-
-//===----------------------------------------------------------------------===//
 // spv.EntryPoint
 //===----------------------------------------------------------------------===//
 
@@ -1916,28 +1890,6 @@
 // spv.GroupNonUniformBallotOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
-                                                OperationState &state) {
-  spirv::Scope executionScope;
-  OpAsmParser::OperandType operandInfo;
-  Type resultType;
-  IntegerType i1Type = parser.getBuilder().getI1Type();
-  if (parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
-      parser.resolveOperand(operandInfo, i1Type, state.operands))
-    return failure();
-
-  return parser.addTypeToList(resultType, state.types);
-}
-
-static void print(spirv::GroupNonUniformBallotOp ballotOp,
-                  OpAsmPrinter &printer) {
-  printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
-          << stringifyScope(ballotOp.execution_scope()) << "\" "
-          << ballotOp.predicate() << " : " << ballotOp.getType();
-}
-
 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
   // TODO(antiagainst): check the result integer type's signedness bit is 0.
 
@@ -1959,25 +1911,6 @@
   build(builder, state, builder->getI1Type(), scope);
 }
 
-static ParseResult parseGroupNonUniformElectOp(OpAsmParser &parser,
-                                               OperationState &state) {
-  spirv::Scope executionScope;
-  Type resultType;
-  if (parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parser.parseColonType(resultType))
-    return failure();
-
-  return parser.addTypeToList(resultType, state.types);
-}
-
-static void print(spirv::GroupNonUniformElectOp groupOp,
-                  OpAsmPrinter &printer) {
-  printer << spirv::GroupNonUniformElectOp::getOperationName() << " \""
-          << stringifyScope(groupOp.execution_scope())
-          << "\" : " << groupOp.getType();
-}
-
 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
   spirv::Scope scope = groupOp.execution_scope();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
@@ -1987,8 +1920,6 @@
   return success();
 }
 
-
-
 //===----------------------------------------------------------------------===//
 // spv.IAdd
 //===----------------------------------------------------------------------===//
@@ -2297,27 +2228,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// spv.MemoryBarrier
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseMemoryBarrierOp(OpAsmParser &parser,
-                                        OperationState &state) {
-  spirv::Scope memoryScope;
-  spirv::MemorySemantics memorySemantics;
-
-  return failure(
-      parseEnumAttribute(memoryScope, parser, state, kMemoryScopeAttrName) ||
-      parser.parseComma() ||
-      parseEnumAttribute(memorySemantics, parser, state));
-}
-
-static void print(spirv::MemoryBarrierOp op, OpAsmPrinter &printer) {
-  printer << spirv::MemoryBarrierOp::getOperationName() << " \""
-          << stringifyScope(op.memory_scope()) << "\", \""
-          << stringifyMemorySemantics(op.memory_semantics()) << "\"";
-}
-
-//===----------------------------------------------------------------------===//
 // spv.module
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -289,7 +289,7 @@ // ----- func @control_barrier_1() -> () { - // expected-error @+1 {{invalid scope attribute specification: "Something"}} + // expected-error @+1 {{invalid execution_scope attribute specification: "Something"}} spv.ControlBarrier "Something", "Device", "Acquire|UniformMemory" return } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -264,6 +264,18 @@ //===----------------------------------------------------------------------===// // Parser Gen +/// Returns if we can format the given attribute as an EnumAttr in the parser +/// format. +static bool canFormatEnumAttr(const NamedAttribute *attr) { + const EnumAttr *enumAttr = dyn_cast(&attr->attr); + if (!enumAttr) + return false; + + // The attribute must have a valid underlying type and a constant builder. + return !enumAttr->getUnderlyingType().empty() && + !enumAttr->getConstBuilderTemplate().empty(); +} + /// The code snippet used to generate a parser call for an attribute. /// /// {0}: The storage type of the attribute. @@ -275,6 +287,30 @@ return failure(); )"; +/// The code snippet used to generate a parser call for an enum attribute. +/// +/// {0}: The name of the attribute. +/// {1}: The c++ namespace for the enum symbolize functions. +/// {2}: The function to symbolize a string of the enum. +/// {3}: The constant builder call to create an attribute of the enum type. 
+const char *const enumAttrParserCode = R"(
+  {
+    StringAttr attrVal;
+    SmallVector<NamedAttribute, 1> attrStorage;
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
+                              "{0}", attrStorage))
+      return failure();
+
+    auto attrOptional = {1}::{2}(attrVal.getValue());
+    if (!attrOptional)
+      return parser.emitError(loc, "invalid ")
+             << "{0} attribute specification: " << attrVal;
+
+    result.addAttribute("{0}", {3});
+  }
+)";
+
 /// The code snippet used to generate a parser call for an operand.
 ///
 /// {0}: The name of the operand.
@@ -383,6 +419,24 @@
     } else if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
       const NamedAttribute *var = attr->getVar();
 
+      // Check to see if we can parse this as an enum attribute.
+      if (canFormatEnumAttr(var)) {
+        const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
+
+        // Generate the code for building an attribute for this enum.
+        std::string attrBuilderStr;
+        {
+          llvm::raw_string_ostream os(attrBuilderStr);
+          os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
+                      "attrOptional.getValue()");
+        }
+
+        body << formatv(enumAttrParserCode, var->name,
+                        enumAttr.getCppNamespace(),
+                        enumAttr.getStringToSymbolFnName(), attrBuilderStr);
+        continue;
+      }
+
       // If this attribute has a buildable type, use that when parsing the
       // attribute.
       std::string attrTypeStr;
@@ -637,7 +691,15 @@
     if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
       const NamedAttribute *var = attr->getVar();
 
-      // Elide the attribute type if it is buildable..
+      // If we are formatting as an enum, symbolize the attribute as a string.
+      if (canFormatEnumAttr(var)) {
+        const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
+        body << "  p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName()
+             << "(" << var->name << "()) << \"\\\"\";\n";
+        continue;
+      }
+
+      // Elide the attribute type if it is buildable.
       Optional<Type> attrType = var->attr.getValueType();
       if (attrType && attrType->getBuilderCall())
         body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";