diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -607,6 +607,10 @@ - Represents all of the results of an operation. +* `successors` + + - Represents all of the successors of an operation. + * `type` ( input ) - Represents the type of the given input. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -455,15 +455,12 @@ // Terminators. def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { let successors = (successor AnySuccessor:$dest); - let parser = [{ return parseBrOp(parser, result); }]; - let printer = [{ printBrOp(p, *this); }]; + let assemblyFormat = "$dest attr-dict"; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { let arguments = (ins LLVMI1:$condition); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); - - let parser = [{ return parseCondBrOp(parser, result); }]; - let printer = [{ printCondBrOp(p, *this); }]; + let assemblyFormat = "$condition `,` successors attr-dict"; } def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>, Arguments<(ins Variadic:$args)> { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -69,6 +69,8 @@ }]; let autogenSerialization = 0; + + let assemblyFormat = "successors attr-dict"; } // ----- diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -250,6 +250,7 @@ }]; let hasCanonicalizer = 1; + let assemblyFormat = "$dest attr-dict"; } def CallOp : Std_Op<"call", [CallOpInterface]> { @@ -602,6 +603,7 @@ }]; let hasCanonicalizer = 1; + let assemblyFormat = "$condition `,` successors attr-dict"; } def ConstantOp : Std_Op<"constant", diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -578,6 +578,11 @@ virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) = 0; + /// Parse an optional operation successor and its operand list. + virtual OptionalParseResult + parseOptionalSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands) = 0; + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -780,69 +780,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::BrOp. -//===----------------------------------------------------------------------===// - -static void printBrOp(OpAsmPrinter &p, BrOp &op) { - p << op.getOperationName() << ' '; - p.printSuccessorAndUseList(op.getOperation(), 0); - p.printOptionalAttrDict(op.getAttrs()); -} - -// ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? -// attribute-dict? -static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) { - Block *dest; - SmallVector operands; - if (parser.parseSuccessorAndUseList(dest, operands) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - result.addSuccessor(dest, operands); - return success(); -} - -//===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::CondBrOp. -//===----------------------------------------------------------------------===// - -static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) { - p << op.getOperationName() << ' ' << op.getOperand(0) << ", "; - p.printSuccessorAndUseList(op.getOperation(), 0); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), 1); - p.printOptionalAttrDict(op.getAttrs()); -} - -// ::= `llvm.cond_br` ssa-use `,` -// bb-id (`[` ssa-use-and-type-list `]`)? `,` -// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? -static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) { - Block *trueDest; - Block *falseDest; - SmallVector trueOperands; - SmallVector falseOperands; - OpAsmParser::OperandType condition; - - Builder &builder = parser.getBuilder(); - auto *llvmDialect = - builder.getContext()->getRegisteredDialect(); - auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect); - - if (parser.parseOperand(condition) || parser.parseComma() || - parser.parseSuccessorAndUseList(trueDest, trueOperands) || - parser.parseComma() || - parser.parseSuccessorAndUseList(falseDest, falseOperands) || - parser.parseOptionalAttrDict(result.attributes) || - parser.resolveOperand(condition, i1Type, result.operands)) - return failure(); - - result.addSuccessor(trueDest, trueOperands); - result.addSuccessor(falseDest, falseOperands); - return success(); -} - //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::ReturnOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1018,24 +1018,6 @@ results.insert(context); } -//===----------------------------------------------------------------------===// -// spv.BranchOp -//===----------------------------------------------------------------------===// - -static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) { - Block *dest; - SmallVector destOperands; - if (parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - state.addSuccessor(dest, destOperands); - return success(); -} - -static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) { - printer << spirv::BranchOp::getOperationName() << ' '; - printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0); -} - //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -414,20 +414,6 @@ }; } // end anonymous namespace. -static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { - Block *dest; - SmallVector destOperands; - if (parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - return success(); -} - -static void print(OpAsmPrinter &p, BranchOp op) { - p << "br "; - p.printSuccessorAndUseList(op.getOperation(), 0); -} - Block *BranchOp::getDest() { return getSuccessor(0); } void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } @@ -810,42 +796,6 @@ }; } // end anonymous namespace. -static ParseResult parseCondBranchOp(OpAsmParser &parser, - OperationState &result) { - SmallVector destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser.getBuilder().getI1Type(); - if (parser.parseOperand(condInfo) || parser.parseComma() || - parser.resolveOperand(condInfo, int1Ty, result.operands)) { - return parser.emitError(parser.getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. - if (parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - - return success(); -} - -static void print(OpAsmPrinter &p, CondBranchOp op) { - p << "cond_br " << op.getCondition() << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); -} - void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -4332,6 +4332,15 @@ return parser.parseSuccessorAndUseList(dest, operands); } + /// Parse an optional operation successor and its operand list. + OptionalParseResult + parseOptionalSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands) override { + if (parser.getToken().isNot(Token::caret_identifier)) + return llvm::None; + return parseSuccessorAndUseList(dest, operands); + } + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -24,8 +24,8 @@ // ----- func @missing_accessor() -> () { + // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}} spv.Branch - // expected-error @+1 {{expected block name}} } // ----- diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -394,7 +394,6 @@ ^bb0: %a = "foo"() : () -> i32 // expected-note {{prior use here}} cond_br %a, ^bb0, ^bb0 // expected-error {{use of value '%a' expects different type than prior uses: 'i1' vs 'i32'}} -// expected-error@-1 {{expected condition type was boolean (i1)}} } // ----- diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -94,10 +94,18 @@ // results // CHECK: error: 'results' directive can not be used as a top-level directive -def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{ +def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{ results }]>; +//===----------------------------------------------------------------------===// +// successors + +// CHECK: error: 'successors' is only valid as a top-level directive +def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{ + type(successors) +}]>; + //===----------------------------------------------------------------------===// // type @@ -231,7 +239,7 @@ // Variables //===----------------------------------------------------------------------===// -// CHECK: error: expected variable to refer to a argument or result +// CHECK: error: expected variable to refer to a argument, result, or successor def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{ $unknown_arg attr-dict }]>; @@ -251,6 +259,18 @@ def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{ $result attr-dict }]>, Results<(outs I64:$result)>; +// CHECK: error: successor 'successor' is already bound +def VariableInvalidF : TestFormat_Op<"variable_invalid_f", [{ + $successor $successor attr-dict +}]> { + let successors = (successor AnySuccessor:$successor); +} +// CHECK: error: successor 'successor' is already bound +def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{ + successors $successor attr-dict +}]> { + let successors = (successor AnySuccessor:$successor); +} //===----------------------------------------------------------------------===// // Coverage Checks diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -49,6 +49,7 @@ FunctionalTypeDirective, OperandsDirective, ResultsDirective, + SuccessorsDirective, TypeDirective, /// This element is a literal. @@ -58,6 +59,7 @@ AttributeVariable, OperandVariable, ResultVariable, + SuccessorVariable, /// This element is an optional element. Optional, @@ -105,6 +107,10 @@ /// This class represents a variable that refers to a result. using ResultVariable = VariableElement; + +/// This class represents a variable that refers to a successor. +using SuccessorVariable = + VariableElement; } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -126,6 +132,11 @@ /// all of the results of an operation. using ResultsDirective = DirectiveElement; +/// This class represents the `successors` directive. This directive represents +/// all of the successors of an operation. +using SuccessorsDirective = + DirectiveElement; + /// This class represents the `attr-dict` directive. This directive represents /// the attribute dictionary of the operation. class AttrDictDirective @@ -294,6 +305,8 @@ /// Generate the c++ to resolve the types of operands and results during /// parsing. void genParserTypeResolution(Operator &op, OpMethodBody &body); + /// Generate the c++ to resolve successors during parsing. + void genParserSuccessorResolution(Operator &op, OpMethodBody &body); /// Generate the operation printer from this format. void genPrinter(Operator &op, OpClass &opClass); @@ -403,6 +416,43 @@ {1}Types = {0}__{1}_functionType.getResults(); )"; +/// The code snippet used to generate a parser call for a successor list. +/// +/// {0}: The name for the successor list. +const char *successorListParserCode = R"( + SmallVector>, 2> {0}; + { + Block *succ; + SmallVector succOperands; + // Parse the first successor. + auto firstSucc = parser.parseOptionalSuccessorAndUseList(succ, + succOperands); + if (firstSucc.hasValue()) { + if (failed(*firstSucc)) + return failure(); + {0}.emplace_back(succ, succOperands); + + // Parse any trailing successors. + while (succeeded(parser.parseOptionalComma())) { + succOperands.clear(); + if (parser.parseSuccessorAndUseList(succ, succOperands)) + return failure(); + {0}.emplace_back(succ, succOperands); + } + } + } +)"; + +/// The code snippet used to generate a parser call for a successor. +/// +/// {0}: The name of the successor. +const char *successorParserCode = R"( + Block *{0}Successor = nullptr; + SmallVector {0}Operands; + if (parser.parseSuccessorAndUseList({0}Successor, {0}Operands)) + return failure(); +)"; + /// Get the name used for the type list for the given type directive operand. /// 'isVariadic' is set to true if the operand has variadic types. static StringRef getTypeListName(Element *arg, bool &isVariadic) { @@ -539,6 +589,8 @@ bool isVariadic = operand->getVar()->isVariadic(); body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode, operand->getVar()->name); + } else if (auto *successor = dyn_cast(element)) { + body << formatv(successorParserCode, successor->getVar()->name); /// Directives. } else if (auto *attrDict = dyn_cast(element)) { @@ -551,6 +603,8 @@ << " SmallVector allOperands;\n" << " if (parser.parseOperandList(allOperands))\n" << " return failure();\n"; + } else if (isa(element)) { + body << llvm::formatv(successorListParserCode, "fullSuccessors"); } else if (auto *dir = dyn_cast(element)) { bool isVariadic = false; StringRef listName = getTypeListName(dir->getOperand(), isVariadic); @@ -586,9 +640,10 @@ for (auto &element : elements) genElementParser(element.get(), body, attrTypeCtx); - // Generate the code to resolve the operand and result types now that they - // have been parsed. + // Generate the code to resolve the operand/result types and successors now + // that they have been parsed. genParserTypeResolution(op, body); + genParserSuccessorResolution(op, body); body << " return success();\n"; } @@ -730,6 +785,23 @@ } } +void OperationFormat::genParserSuccessorResolution(Operator &op, + OpMethodBody &body) { + // Check for the case where all successors were parsed. + bool hasAllSuccessors = llvm::any_of( + elements, [](auto &elt) { return isa(elt.get()); }); + if (hasAllSuccessors) { + body << " for (auto &succAndArgs : fullSuccessors)\n" + << " result.addSuccessor(succAndArgs.first, succAndArgs.second);\n"; + return; + } + + // Otherwise, handle each successor individually. + for (const NamedSuccessor &successor : op.getSuccessors()) + body << llvm::formatv(" result.addSuccessor({0}Successor, {0}Operands);\n", + successor.name); +} + //===----------------------------------------------------------------------===// // PrinterGen @@ -790,8 +862,8 @@ /// Generate the code for printing the given element. static void genElementPrinter(Element *element, OpMethodBody &body, - OperationFormat &fmt, bool &shouldEmitSpace, - bool &lastWasPunctuation) { + OperationFormat &fmt, Operator &op, + bool &shouldEmitSpace, bool &lastWasPunctuation) { if (LiteralElement *literal = dyn_cast(element)) return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, lastWasPunctuation); @@ -808,7 +880,7 @@ // Emit each of the elements. for (Element &childElement : optional->getElements()) - genElementPrinter(&childElement, body, fmt, shouldEmitSpace, + genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace, lastWasPunctuation); body << " }\n"; return; @@ -847,8 +919,16 @@ body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { body << " p << " << operand->getVar()->name << "();\n"; + } else if (auto *successor = dyn_cast(element)) { + unsigned index = successor->getVar() - op.successor_begin(); + body << " p.printSuccessorAndUseList(*this, " << index << ");\n"; } else if (isa(element)) { body << " p << getOperation()->getOperands();\n"; + } else if (isa(element)) { + body << " interleaveComma(llvm::seq(0, " + "getOperation()->getNumSuccessors()), p, [&](int i) {" + << " p.printSuccessorAndUseList(*this, i);" + << " });\n"; } else if (auto *dir = dyn_cast(element)) { body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; @@ -879,7 +959,7 @@ // punctuation. bool shouldEmitSpace = true, lastWasPunctuation = false; for (auto &element : elements) - genElementPrinter(element.get(), body, *this, shouldEmitSpace, + genElementPrinter(element.get(), body, *this, op, shouldEmitSpace, lastWasPunctuation); } @@ -910,6 +990,7 @@ kw_functional_type, kw_operands, kw_results, + kw_successors, kw_type, keyword_end, @@ -1091,6 +1172,7 @@ .Case("functional-type", Token::kw_functional_type) .Case("operands", Token::kw_operands) .Case("results", Token::kw_results) + .Case("successors", Token::kw_successors) .Case("type", Token::kw_type) .Default(Token::identifier); return Token(kind, str); @@ -1164,6 +1246,8 @@ llvm::SMLoc loc, bool isTopLevel); LogicalResult parseResultsDirective(std::unique_ptr &element, llvm::SMLoc loc, bool isTopLevel); + LogicalResult parseSuccessorsDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel); LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, bool isTopLevel); LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element); @@ -1202,9 +1286,11 @@ // The following are various bits of format state used for verification // during parsing. bool hasAllOperands = false, hasAttrDict = false; + bool hasAllSuccessors = false; llvm::SmallBitVector seenOperandTypes, seenResultTypes; llvm::DenseSet seenOperands; llvm::DenseSet seenAttrs; + llvm::DenseSet seenSuccessors; llvm::DenseSet optionalVariables; }; } // end anonymous namespace @@ -1305,6 +1391,17 @@ auto it = buildableTypes.insert({*builder, buildableTypes.size()}); fmt.operandTypes[i].setBuilderIdx(it.first->second); } + + // Check that all of the successors are within the format. + if (!hasAllSuccessors) { + for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { + const NamedSuccessor &successor = op.getSuccessor(i); + if (!seenSuccessors.count(&successor)) { + return emitError(loc, "format missing instance of successor #" + + Twine(i) + "('" + successor.name + "')"); + } + } + } return success(); } @@ -1411,7 +1508,17 @@ element = std::make_unique(result); return success(); } - return emitError(loc, "expected variable to refer to a argument or result"); + /// Successors. + if (const auto *successor = findArg(op.getSuccessors(), name)) { + if (!isTopLevel) + return emitError(loc, "successors can only be used at the top level"); + if (hasAllSuccessors || !seenSuccessors.insert(successor).second) + return emitError(loc, "successor '" + name + "' is already bound"); + element = std::make_unique(successor); + return success(); + } + return emitError( + loc, "expected variable to refer to a argument, result, or successor"); } LogicalResult FormatParser::parseDirective(std::unique_ptr &element, @@ -1432,6 +1539,8 @@ return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel); case Token::kw_results: return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_successors: + return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel); case Token::kw_type: return parseTypeDirective(element, dirTok, isTopLevel); @@ -1616,6 +1725,19 @@ return success(); } +LogicalResult +FormatParser::parseSuccessorsDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel) { + if (!isTopLevel) + return emitError(loc, + "'successors' is only valid as a top-level directive"); + if (hasAllSuccessors || !seenSuccessors.empty()) + return emitError(loc, "'successors' directive creates overlap in format"); + hasAllSuccessors = true; + element = std::make_unique(); + return success(); +} + LogicalResult FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, bool isTopLevel) {