diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -381,6 +381,10 @@ LogicalResult verifyResultsAreFloatLike(Operation *op); LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op); LogicalResult verifyIsTerminator(Operation *op); +LogicalResult verifyZeroSuccessor(Operation *op); +LogicalResult verifyOneSuccessor(Operation *op); +LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); +LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); } // namespace impl @@ -410,6 +414,9 @@ } }; +//===----------------------------------------------------------------------===// +// Operand Traits + namespace detail { /// Utility trait base that provides accessors for derived traits that have /// multiple operands. @@ -522,6 +529,9 @@ class VariadicOperands : public detail::MultiOperandTraitBase {}; +//===----------------------------------------------------------------------===// +// Result Traits + /// This class provides return value APIs for ops that are known to have /// zero results. template @@ -644,6 +654,129 @@ class VariadicResults : public detail::MultiResultTraitBase {}; +//===----------------------------------------------------------------------===// +// Terminator Traits + +/// This class provides the API for ops that are known to be terminators.
+template +class IsTerminator : public TraitBase { +public: + static AbstractOperation::OperationProperties getTraitProperties() { + return static_cast( + OperationProperty::Terminator); + } + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyIsTerminator(op); + } + + unsigned getNumSuccessorOperands(unsigned index) { + return this->getOperation()->getNumSuccessorOperands(index); + } +}; + +/// This class provides verification for ops that are known to have zero +/// successors. +template +class ZeroSuccessor : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyZeroSuccessor(op); + } +}; + +namespace detail { +/// Utility trait base that provides accessors for derived traits that have +/// multiple successors. +template class TraitType> +struct MultiSuccessorTraitBase : public TraitBase { + using succ_iterator = Operation::succ_iterator; + using succ_range = SuccessorRange; + + /// Return the number of successors. + unsigned getNumSuccessors() { + return this->getOperation()->getNumSuccessors(); + } + + /// Return the successor at `index`. + Block *getSuccessor(unsigned index) { + return this->getOperation()->getSuccessor(index); + } + + /// Set the successor at `index` to `block`. + void setSuccessor(Block *block, unsigned index) { + this->getOperation()->setSuccessor(block, index); + } + + /// Successor iterator access. + succ_iterator succ_begin() { return this->getOperation()->succ_begin(); } + succ_iterator succ_end() { return this->getOperation()->succ_end(); } + succ_range getSuccessors() { return this->getOperation()->getSuccessors(); } +}; +} // end namespace detail + +/// This class provides APIs for ops that are known to have a single successor.
+template +class OneSuccessor : public TraitBase { +public: + Block *getSuccessor() { return this->getOperation()->getSuccessor(0); } + void setSuccessor(Block *succ) { + this->getOperation()->setSuccessor(succ, 0); + } + + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOneSuccessor(op); + } +}; + +/// This class provides the API for ops that are known to have a specified +/// number of successors. +template +class NSuccessors { +public: + static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2"); + + template + class Impl : public detail::MultiSuccessorTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyNSuccessors(op, N); + } + }; +}; + +/// This class provides APIs for ops that are known to have at least a specified +/// number of successors. +template +class AtLeastNSuccessors { +public: + template + class Impl + : public detail::MultiSuccessorTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyAtLeastNSuccessors(op, N); + } + }; +}; + +/// This class provides the API for ops which have an unknown number of +/// successors. +template +class VariadicSuccessors + : public detail::MultiSuccessorTraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + // There is no count to check, but each successor block must still be + // verified; a lower bound of zero is always satisfied. + return impl::verifyAtLeastNSuccessors(op, /*numSuccessors=*/0); + } +}; + +//===----------------------------------------------------------------------===// +// Misc Traits + /// This class provides verification for ops that are known to have the same /// operand shape: all operands are scalars, vectors/tensors of the same /// shape. @@ -789,41 +922,6 @@ } }; -/// This class provides the API for ops that are known to be terminators.
-template -class IsTerminator : public TraitBase { -public: - static AbstractOperation::OperationProperties getTraitProperties() { - return static_cast( - OperationProperty::Terminator); - } - static LogicalResult verifyTrait(Operation *op) { - return impl::verifyIsTerminator(op); - } - - unsigned getNumSuccessors() { - return this->getOperation()->getNumSuccessors(); - } - unsigned getNumSuccessorOperands(unsigned index) { - return this->getOperation()->getNumSuccessorOperands(index); - } - - Block *getSuccessor(unsigned index) { - return this->getOperation()->getSuccessor(index); - } - - void setSuccessor(Block *block, unsigned index) { - return this->getOperation()->setSuccessor(block, index); - } - - void addSuccessorOperand(unsigned index, Value value) { - return this->getOperation()->addSuccessorOperand(index, value); - } - void addSuccessorOperands(unsigned index, ArrayRef values) { - return this->getOperation()->addSuccessorOperand(index, values); - } -}; - /// This class provides the API for ops that are known to be isolated from /// above. template diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1895,7 +1895,7 @@ return false; auto branchOp = dyn_cast(srcBlock.back()); - return branchOp && branchOp.getSuccessor(0) == &dstBlock; + return branchOp && branchOp.getSuccessor() == &dstBlock; } static LogicalResult verify(spirv::LoopOp loopOp) { diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -415,9 +415,9 @@ }; } // end anonymous namespace.
-Block *BranchOp::getDest() { return getSuccessor(0); } +Block *BranchOp::getDest() { return getSuccessor(); } -void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); } +void BranchOp::setDest(Block *block) { setSuccessor(block); } void BranchOp::eraseOperand(unsigned index) { getOperation()->eraseSuccessorOperand(0, index); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -942,6 +942,14 @@ return success(); } +LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { + Block *block = op->getBlock(); + // Verify that the operation is at the end of the respective parent block. + if (!block || &block->back() != op) + return op->emitOpError("must be the last operation in the parent block"); + return success(); +} + static LogicalResult verifySuccessor(Operation *op, unsigned succNo) { Operation::operand_range operands = op->getSuccessorOperands(succNo); unsigned operandCount = op->getNumSuccessorOperands(succNo); @@ -976,18 +984,49 @@ return success(); } -LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { - Block *block = op->getBlock(); - // Verify that the operation is at the end of the respective parent block. - if (!block || &block->back() != op) - return op->emitOpError("must be the last operation in the parent block"); - - // Verify the state of the successor blocks.
- if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op))) - return failure(); +// Verify that `op` has no successors. +LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { + if (op->getNumSuccessors() != 0) { + return op->emitOpError("requires 0 successors but found ") + << op->getNumSuccessors(); + } return success(); } +// Verify that `op` has exactly one successor, then verify the state of that +// successor block. +LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { + if (op->getNumSuccessors() != 1) { + return op->emitOpError("requires 1 successor but found ") + << op->getNumSuccessors(); + } + return verifyTerminatorSuccessors(op); +} + +// Verify that `op` has exactly `numSuccessors` successors, then verify the +// state of each successor block. +LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, + unsigned numSuccessors) { + if (op->getNumSuccessors() != numSuccessors) { + return op->emitOpError("requires ") + << numSuccessors << " successors but found " + << op->getNumSuccessors(); + } + return verifyTerminatorSuccessors(op); +} + +// Verify that `op` has at least `numSuccessors` successors, then verify the +// state of each successor block. +LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, + unsigned numSuccessors) { + if (op->getNumSuccessors() < numSuccessors) { + return op->emitOpError("requires at least ") + << numSuccessors << " successors but found " + << op->getNumSuccessors(); + } + return verifyTerminatorSuccessors(op); +} + LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { for (auto resultType : op->getResultTypes()) { auto elementType = getTensorOrVectorElementType(resultType); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -352,7 +352,7 @@ // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes.
if (auto brOp = dyn_cast(opInst)) { - builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); + builder.CreateBr(blockMapping[brOp.getSuccessor()]); return success(); } if (auto condbrOp = dyn_cast(opInst)) { diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -24,7 +24,7 @@ // ----- func @missing_accessor() -> () { - // expected-error @+1 {{has incorrect number of successors: expected 1 but found 0}} + // expected-error @+1 {{requires 1 successor but found 0}} spv.Branch } @@ -32,7 +32,7 @@ func @wrong_accessor_count() -> () { %true = spv.constant true - // expected-error @+1 {{incorrect number of successors: expected 1 but found 2}} + // expected-error @+1 {{requires 1 successor but found 2}} "spv.Branch"()[^one, ^two] : () -> () ^one: spv.Return @@ -116,7 +116,7 @@ func @wrong_accessor_count() -> () { %true = spv.constant true - // expected-error @+1 {{incorrect number of successors: expected 2 but found 1}} + // expected-error @+1 {{requires 2 successors but found 1}} "spv.BranchConditional"(%true)[^one] : (i1) -> () ^one: spv.Return diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -54,7 +54,7 @@ // CHECK: ArrayRef tblgen_operands; // CHECK: }; -// CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl +// CHECK: class AOp : public Op::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl // CHECK: public: // CHECK: using Op::Op; // CHECK: using OperandAdaptor = AOpOperandAdaptor; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1376,26 +1376,8 @@ } void
OpEmitter::genSuccessorVerifier(OpMethodBody &body) { unsigned numSuccessors = op.getNumSuccessors(); - - const char *checkSuccessorSizeCode = R"( - if (this->getOperation()->getNumSuccessors() {0} {1}) { - return emitOpError("has incorrect number of successors: expected{2} {1}" - " but found ") - << this->getOperation()->getNumSuccessors(); - } - )"; - - // Verify this op has the correct number of successors. - unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); - if (numVariadicSuccessors == 0) { - body << formatv(checkSuccessorSizeCode, "!=", numSuccessors, ""); - } else if (numVariadicSuccessors != numSuccessors) { - body << formatv(checkSuccessorSizeCode, "<", - numSuccessors - numVariadicSuccessors, " at least"); - } - // If we have no successors, there is nothing more to do. if (numSuccessors == 0) return; @@ -1427,31 +1409,44 @@ body << " }\n"; } +/// Add a size count trait of kind `traitKind` to the given operation class. +static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, + int numTotal, int numVariadic) { + if (numVariadic != 0) { + if (numTotal == numVariadic) + opClass.addTrait("OpTrait::Variadic" + traitKind + "s"); + else + opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" + + Twine(numTotal - numVariadic) + ">::Impl"); + return; + } + switch (numTotal) { + case 0: + opClass.addTrait("OpTrait::Zero" + traitKind); + break; + case 1: + opClass.addTrait("OpTrait::One" + traitKind); + break; + default: + opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numTotal) + + ">::Impl"); + break; + } +} + void OpEmitter::genTraits() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults(); // Add return size trait.
- if (numVariadicResults != 0) { - if (numResults == numVariadicResults) - opClass.addTrait("OpTrait::VariadicResults"); - else - opClass.addTrait("OpTrait::AtLeastNResults<" + - Twine(numResults - numVariadicResults) + ">::Impl"); - } else { - switch (numResults) { - case 0: - opClass.addTrait("OpTrait::ZeroResult"); - break; - case 1: - opClass.addTrait("OpTrait::OneResult"); - break; - default: - opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl"); - break; - } - } + addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); + + // Add successor size trait. + unsigned numSuccessors = op.getNumSuccessors(); + unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); + addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors); + // Add the native and interface traits. for (const auto &trait : op.getTraits()) { if (auto opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getTrait());