diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -31,6 +31,25 @@ CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, "LLVM dialect integer">; +def LLVMIntBase : TypeConstraint< + And<[LLVM_Type.predicate, + CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, + "LLVM dialect integer">; + +// LLVM dialect integer type of a specific bit width. +class LLVMI + : Type().isIntegerTy(" # width # ")">]>, + "LLVM dialect " # width # "-bit integer">, + BuildableType< + "::mlir::LLVM::LLVMType::getIntNTy(" + "$_builder.getContext()->getRegisteredDialect()," + # width # ")">; + +def LLVMI1 : LLVMI<1>; + // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder // used to translate to LLVM IR proper. class LLVM_OpBase traits = []> : diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -72,8 +72,7 @@ // Base class for LLVM terminator operations. All terminator operations have // zero results and an optional list of successors.
class LLVM_TerminatorOp traits = []> : - LLVM_Op, - Arguments<(ins Variadic:$args)>, Results<(outs)> { + LLVM_Op { let builders = [ OpBuilder< "Builder *, OperationState &result, " @@ -320,15 +319,10 @@ Arguments<(ins OptionalAttr:$callee, Variadic)>, Results<(outs Variadic)> { + let successors = (successor AnySuccessor:$normalDest, + AnySuccessor:$unwindDest); + let builders = [OpBuilder< - "Builder *b, OperationState &result, ArrayRef tys, " - "FlatSymbolRefAttr callee, ValueRange ops, Block* normal, " - "ValueRange normalOps, Block* unwind, ValueRange unwindOps", - [{ - result.addAttribute("callee", callee); - build(b, result, tys, ops, normal, normalOps, unwind, unwindOps); - }]>, - OpBuilder< "Builder *b, OperationState &result, ArrayRef tys, " "ValueRange ops, Block* normal, " "ValueRange normalOps, Block* unwind, ValueRange unwindOps", @@ -460,19 +454,19 @@ // Terminators. def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { + let successors = (successor AnySuccessor:$dest); let parser = [{ return parseBrOp(parser, result); }]; let printer = [{ printBrOp(p, *this); }]; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { - let verifier = [{ - if (getNumSuccessors() != 2) - return emitOpError("expected exactly two successors"); - return success(); - }]; + let arguments = (ins LLVMI1:$condition); + let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); + let parser = [{ return parseCondBrOp(parser, result); }]; let printer = [{ printCondBrOp(p, *this); }]; } -def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> { +def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>, + Arguments<(ins Variadic:$args)> { string llvmBuilder = [{ if ($_numOperands != 0) builder.CreateRet($args[0]); diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -41,12 +41,14 @@
``` }]; - let arguments = (ins - Variadic:$block_arguments - ); + let arguments = (ins); let results = (outs); + let successors = (successor AnySuccessor:$target); + + let verifier = [{ return success(); }]; + let builders = [ OpBuilder< "Builder *, OperationState &state, " @@ -60,12 +62,10 @@ let extraClassDeclaration = [{ /// Returns the branch target block. - Block *getTarget() { return getOperation()->getSuccessor(0); } + Block *getTarget() { return target(); } /// Returns the block arguments. - operand_range getBlockArguments() { - return getOperation()->getSuccessorOperands(0); - } + operand_range getBlockArguments() { return targetOperands(); } }]; let autogenSerialization = 0; @@ -115,12 +115,14 @@ let arguments = (ins SPV_Bool:$condition, - Variadic:$branch_arguments, OptionalAttr:$branch_weights ); let results = (outs); + let successors = (successor AnySuccessor:$trueTarget, + AnySuccessor:$falseTarget); + let builders = [ OpBuilder< "Builder *builder, OperationState &state, Value condition, " diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -232,12 +232,10 @@ ^bb3(%3: tensor<*xf32>): }]; - let arguments = (ins Variadic:$operands); + let successors = (successor AnySuccessor:$dest); - let builders = [OpBuilder< - "Builder *, OperationState &result, Block *dest," - "ValueRange operands = {}", [{ - result.addSuccessor(dest, operands); + let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest", [{ + result.addSuccessor(dest, llvm::None); }]>]; // BranchOp is fully verified by traits. @@ -513,16 +511,8 @@ ...
}]; - let arguments = (ins I1:$condition, Variadic:$branchOperands); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value condition," - "Block *trueDest, ValueRange trueOperands," - "Block *falseDest, ValueRange falseOperands", [{ - result.addOperands(condition); - result.addSuccessor(trueDest, trueOperands); - result.addSuccessor(falseDest, falseOperands); - }]>]; + let arguments = (ins I1:$condition); + let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); // CondBranchOp is fully verified by traits. let verifier = ?; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -185,6 +185,10 @@ class RegionConstraint : Constraint; +// Subclass for constraints on a successor. +class SuccessorConstraint : + Constraint; + // How to use these constraint categories: // // * Use TypeConstraint to specify @@ -1338,6 +1342,16 @@ CPred<"$_self.getBlocks().size() == " # numBlocks>, "region with " # numBlocks # " blocks">; +//===----------------------------------------------------------------------===// +// Successor definitions +//===----------------------------------------------------------------------===// + +class Successor : + SuccessorConstraint; + +// Any successor. +def AnySuccessor : Successor, "any successor">; + //===----------------------------------------------------------------------===// // OpTrait definitions //===----------------------------------------------------------------------===// @@ -1530,6 +1544,9 @@ // Marker used to identify the region list for an op. def region; +// Marker used to identify the successor list for an op. +def successor; + // Class for defining a custom builder. // // TableGen generates several generic builders for each op by default (see @@ -1580,6 +1597,9 @@ // The list of regions of the op. Default to 0 regions. dag regions = (region); + // The list of successors of the op.
Default to 0 successors. + dag successors = (successor); + // Attribute getters can be added to the op by adding an Attr member // with the name and type of the attribute. E.g., adding int attribute // with name "value" and type "i32": diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -48,7 +48,7 @@ StringRef getDescription() const; // Constraint kind - enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized }; + enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized }; Kind getKind() const { return kind; } diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -19,6 +19,7 @@ #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Region.h" +#include "mlir/TableGen/Successor.h" #include "mlir/TableGen/Type.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallVector.h" @@ -138,6 +139,17 @@ // Returns the `index`-th region. const NamedRegion &getRegion(unsigned index) const; + // Successors. + using const_successor_iterator = const NamedSuccessor *; + const_successor_iterator successor_begin() const; + const_successor_iterator successor_end() const; + llvm::iterator_range getSuccessors() const; + + // Returns the number of successors. + unsigned getNumSuccessors() const; + // Returns the `index`-th successor. + const NamedSuccessor &getSuccessor(unsigned index) const; + // Trait. using const_trait_iterator = const OpTrait *; const_trait_iterator trait_begin() const; @@ -193,6 +205,9 @@ // The results of the op. SmallVector results; + // The successors of this op. + SmallVector successors; + // The traits of the op.
SmallVector traits; diff --git a/mlir/include/mlir/TableGen/Successor.h b/mlir/include/mlir/TableGen/Successor.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/Successor.h @@ -0,0 +1,38 @@ +//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_SUCCESSOR_H_ +#define MLIR_TABLEGEN_SUCCESSOR_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Constraint.h" + +namespace mlir { +namespace tblgen { + +// Wrapper class providing helper methods for accessing successor constraints +// defined in TableGen. +class Successor : public Constraint { +public: + using Constraint::Constraint; + + static bool classof(const Constraint *c) { + return c->getKind() == CK_Successor; + } +}; + +// A struct bundling a successor's constraint and its name.
+struct NamedSuccessor { + StringRef name; + Successor constraint; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_SUCCESSOR_H_ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -234,15 +234,14 @@ static LogicalResult verify(InvokeOp op) { if (op.getNumResults() > 1) return op.emitOpError("must have 0 or 1 result"); - if (op.getNumSuccessors() != 2) - return op.emitOpError("must have normal and unwind destinations"); - if (op.getSuccessor(1)->empty()) + Block *unwindDest = op.unwindDest(); + if (unwindDest->empty()) return op.emitError( "must have at least one operation in unwind destination"); // In unwind destination, first operation must be LandingpadOp - if (!isa(op.getSuccessor(1)->front())) + if (!isa(unwindDest->front())) return op.emitError("first operation in unwind destination should be a " "llvm.landingpad operation"); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1036,14 +1036,6 @@ printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0); } -static LogicalResult verify(spirv::BranchOp branchOp) { - auto *op = branchOp.getOperation(); - if (op->getNumSuccessors() != 1) - branchOp.emitOpError("must have exactly one successor"); - - return success(); -} - //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// @@ -1114,10 +1106,6 @@ } static LogicalResult verify(spirv::BranchConditionalOp branchOp) { - auto *op = branchOp.getOperation(); - if (op->getNumSuccessors() != 2) - return branchOp.emitOpError("must have exactly two successors"); - if (auto weights = branchOp.branch_weights()) { if
(weights->getValue().size() != 2) { return branchOp.emitOpError("must have exactly two branch weights"); diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -23,6 +23,8 @@ kind = CK_Attr; } else if (record->isSubClassOf("RegionConstraint")) { kind = CK_Region; + } else if (record->isSubClassOf("SuccessorConstraint")) { + kind = CK_Successor; } else { assert(record->isSubClassOf("Constraint")); } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -159,6 +159,26 @@ return regions[index]; } +auto tblgen::Operator::successor_begin() const -> const_successor_iterator { + return successors.begin(); +} +auto tblgen::Operator::successor_end() const -> const_successor_iterator { + return successors.end(); +} +auto tblgen::Operator::getSuccessors() const + -> llvm::iterator_range { + return {successor_begin(), successor_end()}; +} + +unsigned tblgen::Operator::getNumSuccessors() const { + return successors.size(); +} + +const tblgen::NamedSuccessor & +tblgen::Operator::getSuccessor(unsigned index) const { + return successors[index]; +} + auto tblgen::Operator::trait_begin() const -> const_trait_iterator { return traits.begin(); } @@ -285,6 +305,24 @@ results.push_back({name, TypeConstraint(resultDef)}); } + // Handle successors declared in the 'successors' dag. + auto *successorsDag = def.getValueAsDag("successors"); + auto *successorsOp = dyn_cast(successorsDag->getOperator()); + if (!successorsOp || successorsOp->getDef()->getName() != "successor") { + PrintFatalError(def.getLoc(), + "'successors' must have 'successor' directive"); + } + + for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { + auto name = successorsDag->getArgNameStr(i); + auto *successorInit = dyn_cast(successorsDag->getArg(i)); + if (!successorInit) { + PrintFatalError(def.getLoc(), + Twine("undefined kind for
successor #") + Twine(i)); + } + successors.push_back({name, Successor(successorInit->getDef())}); + } + // Create list of traits, skipping over duplicates: appending to lists in // tablegen is easy, making them unique less so, so dedupe here. if (auto traitList = def.getValueAsListInit("traits")) { diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -32,7 +32,7 @@ func @wrong_accessor_count() -> () { %true = spv.constant true - // expected-error @+1 {{must have exactly one successor}} + // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}} "spv.Branch"()[^one, ^two] : () -> () ^one: spv.Return @@ -116,7 +116,7 @@ func @wrong_accessor_count() -> () { %true = spv.constant true - // expected-error @+1 {{must have exactly two successors}} + // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}} "spv.BranchConditional"(%true)[^one] : (i1) -> () ^one: spv.Return diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -431,7 +431,7 @@ [(IsNotScalar $attr)]>; def TestBranchOp : TEST_Op<"br", [Terminator]> { - let arguments = (ins Variadic:$operands); + let successors = (successor AnySuccessor:$target); } def AttrSizedOperandOp : TEST_Op<"attr_sized_operands", diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -180,6 +180,9 @@ // Generates getters for named regions. void genNamedRegionGetters(); + // Generates getters for named successors. + void genNamedSuccessorGetters(); + // Generates builder methods for the operation.
void genBuilder(); @@ -263,6 +266,10 @@ // The generated code will be attached to `body`. void genRegionVerifier(OpMethodBody &body); + // Generates verify statements for successors in the operation. + // The generated code will be attached to `body`. + void genSuccessorVerifier(OpMethodBody &body); + // Generates the traits used by the object. void genTraits(); @@ -299,6 +306,7 @@ genNamedOperandGetters(); genNamedResultGetters(); genNamedRegionGetters(); + genNamedSuccessorGetters(); genAttrGetters(); genBuilder(); genParser(); @@ -556,6 +564,31 @@ } } +void OpEmitter::genNamedSuccessorGetters() { + unsigned numSuccessors = op.getNumSuccessors(); + for (unsigned i = 0; i < numSuccessors; ++i) { + const NamedSuccessor &successor = op.getSuccessor(i); + if (successor.name.empty()) + continue; + + // Generate the block getter. + auto &m = opClass.newMethod("Block *", successor.name); + m.body() << formatv(" return this->getOperation()->getSuccessor({0});", i); + + // Generate the all-operands getter. + auto &operandsMethod = opClass.newMethod( + "Operation::operand_range", (successor.name + "Operands").str()); + operandsMethod.body() << formatv( + " return this->getOperation()->getSuccessorOperands({0});", i); + + // Generate the individual-operand getter. + auto &operandMethod = opClass.newMethod( + "Value", (successor.name + "Operand").str(), "unsigned index"); + operandMethod.body() << formatv( + " return this->getOperation()->getSuccessorOperand({0}, index);", i); + } +} + static bool canGenerateUnwrappedBuilder(Operator &op) { // If this op does not have native attributes at all, return directly to avoid // redefining builders. @@ -846,8 +879,9 @@ // Generate builder that infers type too. // TODO(jpienaar): Subsume this with general checking if type can be infered // automatically. - // TODO(jpienaar): Expand to handle regions.
- if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) + // TODO(jpienaar): Expand to handle regions and successors. + if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0 && + op.getNumSuccessors() == 0) genInferedTypeCollectiveParamBuilder(); } @@ -959,17 +993,24 @@ ++numAttrs; } } + + // Insert a block parameter and an operand-range parameter per successor. + for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { + paramList += llvm::formatv(", Block *{0}, ValueRange {0}Operands", + namedSuccessor.name) + .str(); + } } void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, bool isRawValueAttr) { - // Push all operands to the result + // Push all operands to the result. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { body << " " << builderOpState << ".addOperands(" << getArgumentName(op, i) << ");\n"; } - // Push all attributes to the result + // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; if (!attr.isDerivedAttr()) { @@ -1007,11 +1048,16 @@ } } - // Create the correct number of regions + // Create the correct number of regions. if (int numRegions = op.getNumRegions()) { for (int i = 0; i < numRegions; ++i) body << " (void)" << builderOpState << ".addRegion();\n"; } + + // Push all successors to the result.
+ for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) + body << formatv(" {0}.addSuccessor({1}, {1}Operands);\n", builderOpState, + namedSuccessor.name); } void OpEmitter::genCanonicalizerDecls() { @@ -1205,6 +1251,7 @@ } genRegionVerifier(body); + genSuccessorVerifier(body); if (hasCustomVerify) { FmtContext fctx; @@ -1282,6 +1329,33 @@ } } +void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { + unsigned numSuccessors = op.getNumSuccessors(); + + // Verify this op has the correct number of successors. + body << formatv( + " if (this->getOperation()->getNumSuccessors() != {0}) {\n " + "return emitOpError(\"has incorrect number of successors: expected {0} " + "but found \") << this->getOperation()->getNumSuccessors();\n }\n", + numSuccessors); + + for (unsigned i = 0; i < numSuccessors; ++i) { + const auto &successor = op.getSuccessor(i); + + auto getSuccessor = + formatv("this->getOperation()->getSuccessor({0})", i).str(); + auto constraint = tgfmt(successor.constraint.getConditionTemplate(), + &verifyCtx.withSelf(getSuccessor)) + .str(); + + body << formatv(" if (!({0})) {\n " + "return emitOpError(\"successor #{1} ('{2}') failed to " + "verify constraint: {3}\");\n }\n", + constraint, i, successor.name, + successor.constraint.getDescription()); + } +} + void OpEmitter::genTraits() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults(); @@ -1319,7 +1393,9 @@ int numVariadicOperands = op.getNumVariadicOperands(); // Add operand size trait. - if (numVariadicOperands != 0) { + // Note: Successor operands are also included in the operation's operand list, + // so we always need to use VariadicOperands in the presence of successors. + if (numVariadicOperands != 0 || op.getNumSuccessors()) { if (numOperands == numVariadicOperands) opClass.addTrait("OpTrait::VariadicOperands"); else