diff --git a/mlir/include/mlir/Analysis/CMakeLists.txt b/mlir/include/mlir/Analysis/CMakeLists.txt --- a/mlir/include/mlir/Analysis/CMakeLists.txt +++ b/mlir/include/mlir/Analysis/CMakeLists.txt @@ -3,6 +3,11 @@ mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRCallOpInterfacesIncGen) +set(LLVM_TARGET_DEFINITIONS ControlFlowInterfaces.td) +mlir_tablegen(ControlFlowInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ControlFlowInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRControlFlowInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td) mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls) mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs) diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.h b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h @@ -0,0 +1,43 @@ +//===- ControlFlowInterfaces.h - ControlFlow Interfaces ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the branch interfaces defined in +// `ControlFlowInterfaces.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H +#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class BranchOpInterface; + +namespace detail { +/// Erase an operand from a branch operation that is used as a successor +/// operand. `operandIndex` is the operand within `operands` to be erased. +void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex, + Operation *op); + +/// Return the `BlockArgument` corresponding to operand `operandIndex` in some +/// successor if `operandIndex` is within the range of `operands`, or None if +/// `operandIndex` isn't a successor operand index. +Optional +getBranchSuccessorArgument(Optional operands, + unsigned operandIndex, Block *successor); + +/// Verify that the given operands match those of the given successor block. +LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, + Optional operands); +} // namespace detail + +#include "mlir/Analysis/ControlFlowInterfaces.h.inc" +} // end namespace mlir + +#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.td b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td @@ -0,0 +1,85 @@ +//===-- ControlFlowInterfaces.td - ControlFlow Interfaces --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to define information +// about control flow operations, e.g. branches. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES +#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// BranchOpInterface +//===----------------------------------------------------------------------===// + +def BranchOpInterface : OpInterface<"BranchOpInterface"> { + let description = [{ + This interface provides information for branching terminator operations, + i.e. terminator operations with successors. + }]; + let methods = [ + InterfaceMethod<[{ + Returns a set of values that correspond to the arguments to the + successor at the given index. Returns None if the operands to the + successor are non-materialized values, i.e. they are internal to the + operation. + }], + "Optional", "getSuccessorOperands", (ins "unsigned":$index) + >, + InterfaceMethod<[{ + Return true if this operation can erase an operand to a successor block. + }], + "bool", "canEraseSuccessorOperand" + >, + InterfaceMethod<[{ + Erase the operand at `operandIndex` from the `index`-th successor. This + should only be called if `canEraseSuccessorOperand` returns true. + }], + "void", "eraseSuccessorOperand", + (ins "unsigned":$index, "unsigned":$operandIndex), [{}], + /*defaultImplementation=*/[{ + ConcreteOp *op = static_cast(this); + Optional operands = op->getSuccessorOperands(index); + assert(operands && "unable to query operands for successor"); + detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op); + }] + >, + InterfaceMethod<[{ + Returns the `BlockArgument` corresponding to operand `operandIndex` in + some successor, or None if `operandIndex` isn't a successor operand + index. + }], + "Optional", "getSuccessorBlockArgument", + (ins "unsigned":$operandIndex), [{ + Operation *opaqueOp = op; + for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) { + if (Optional arg = detail::getBranchSuccessorArgument( + op.getSuccessorOperands(i), operandIndex, + opaqueOp->getSuccessor(i))) + return arg; + } + return llvm::None; + }] + > + ]; + + let verify = [{ + auto concreteOp = cast($_op); + for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) { + Optional operands = concreteOp.getSuccessorOperands(i); + if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands))) + return failure(); + } + return success(); + }]; +} + +#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -14,6 +14,7 @@ #define LLVMIR_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Analysis/ControlFlowInterfaces.td" class LLVM_Builder { string llvmBuilder = builder; @@ -315,7 +316,9 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">; // Call-related operations. -def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>, +def LLVM_InvokeOp : LLVM_Op<"invoke", [ + DeclareOpInterfaceMethods, Terminator + ]>, Arguments<(ins OptionalAttr:$callee, Variadic)>, Results<(outs Variadic)> { @@ -458,11 +461,13 @@ } // Terminators. -def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { +def LLVM_BrOp : LLVM_TerminatorOp<"br", + [DeclareOpInterfaceMethods]> { let successors = (successor AnySuccessor:$dest); let assemblyFormat = "$dest attr-dict"; } -def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> { +def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", + [DeclareOpInterfaceMethods]> { let arguments = (ins LLVMI1:$condition); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let assemblyFormat = "$condition `,` successors attr-dict"; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -16,10 +16,13 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td" include "mlir/Analysis/CallInterfaces.td" +include "mlir/Analysis/ControlFlowInterfaces.td" // ----- -def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> { +def SPV_BranchOp : SPV_Op<"Branch",[ + DeclareOpInterfaceMethods, InFunctionScope, + Terminator]> { let summary = "Unconditional branch to target block."; let description = [{ @@ -75,8 +78,9 @@ // ----- -def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", - [InFunctionScope, Terminator]> { +def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [ + DeclareOpInterfaceMethods, InFunctionScope, + Terminator]> { let summary = [{ If Condition is true, branch to true block, otherwise branch to false block. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_ #define MLIR_DIALECT_SPIRV_SPIRVOPS_H_ +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Function.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H #include "mlir/Analysis/CallInterfaces.h" +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -14,6 +14,7 @@ #define STANDARD_OPS include "mlir/Analysis/CallInterfaces.td" +include "mlir/Analysis/ControlFlowInterfaces.td" include "mlir/IR/OpAsmInterface.td" def Std_Dialect : Dialect { @@ -331,7 +332,8 @@ // BranchOp //===----------------------------------------------------------------------===// -def BranchOp : Std_Op<"br", [Terminator]> { +def BranchOp : Std_Op<"br", + [DeclareOpInterfaceMethods, Terminator]> { let summary = "branch operation"; let description = [{ The "br" operation represents a branch operation in a function. @@ -668,7 +670,8 @@ // CondBranchOp //===----------------------------------------------------------------------===// -def CondBranchOp : Std_Op<"cond_br", [Terminator]> { +def CondBranchOp : Std_Op<"cond_br", + [DeclareOpInterfaceMethods, Terminator]> { let summary = "conditional branch operation"; let description = [{ The "cond_br" operation represents a conditional branch operation in a diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -639,6 +639,10 @@ type_range getTypes() const { return {begin(), end()}; } auto getType() const { return getTypes(); } + /// Return the operand index of the first element of this range. The range + /// must not be empty. + unsigned getBeginOperandIndex() const; + private: /// See `detail::indexed_accessor_range_base` for details. static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -2,6 +2,7 @@ AffineAnalysis.cpp AffineStructures.cpp CallGraph.cpp + ControlFlowInterfaces.cpp Dominance.cpp InferTypeOpInterface.cpp Liveness.cpp @@ -14,6 +15,7 @@ add_llvm_library(MLIRAnalysis CallGraph.cpp + ControlFlowInterfaces.cpp InferTypeOpInterface.cpp Liveness.cpp SliceAnalysis.cpp @@ -26,6 +28,7 @@ add_dependencies(MLIRAnalysis MLIRAffineOps MLIRCallOpInterfacesIncGen + MLIRControlFlowInterfacesIncGen MLIRTypeInferOpInterfaceIncGen MLIRLoopOps ) @@ -45,6 +48,7 @@ add_dependencies(MLIRLoopAnalysis MLIRAffineOps MLIRCallOpInterfacesIncGen + MLIRControlFlowInterfacesIncGen MLIRTypeInferOpInterfaceIncGen MLIRLoopOps ) diff --git a/mlir/lib/Analysis/ControlFlowInterfaces.cpp b/mlir/lib/Analysis/ControlFlowInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/ControlFlowInterfaces.cpp @@ -0,0 +1,101 @@ +//===- ControlFlowInterfaces.h - ControlFlow Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/ControlFlowInterfaces.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ControlFlowInterfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/ControlFlowInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// BranchOpInterface +//===----------------------------------------------------------------------===// + +/// Erase an operand from a branch operation that is used as a successor +/// operand. 'operandIndex' is the operand within 'operands' to be erased. +void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands, + unsigned operandIndex, + Operation *op) { + assert(operandIndex < operands.size() && + "invalid index for successor operands"); + + // Erase the operand from the operation. + size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex; + op->eraseOperand(fullOperandIndex); + + // If this operation has an OperandSegmentSizeAttr, keep it up to date. + auto operandSegmentAttr = + op->getAttrOfType("operand_segment_sizes"); + if (!operandSegmentAttr) + return; + + // Find the segment containing the full operand index and decrement it. + // TODO: This seems like a general utility that could be added somewhere. + SmallVector values(operandSegmentAttr.getValues()); + unsigned currentSize = 0; + for (unsigned i = 0, e = values.size(); i != e; ++i) { + currentSize += values[i]; + if (fullOperandIndex < currentSize) { + --values[i]; + break; + } + } + op->setAttr("operand_segment_sizes", + DenseIntElementsAttr::get(operandSegmentAttr.getType(), values)); +} + +/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some +/// successor if 'operandIndex' is within the range of 'operands', or None if +/// `operandIndex` isn't a successor operand index. +Optional mlir::detail::getBranchSuccessorArgument( + Optional operands, unsigned operandIndex, Block *successor) { + // Check that the operands are valid. + if (!operands || operands->empty()) + return llvm::None; + + // Check to ensure that this operand is within the range. + unsigned operandsStart = operands->getBeginOperandIndex(); + if (operandIndex < operandsStart || + operandIndex >= (operandsStart + operands->size())) + return llvm::None; + + // Index the successor. + unsigned argIndex = operandIndex - operandsStart; + return successor->getArgument(argIndex); +} + +/// Verify that the given operands match those of the given successor block. +LogicalResult +mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, + Optional operands) { + if (!operands) + return success(); + + // Check the count. + unsigned operandCount = operands->size(); + Block *destBB = op->getSuccessor(succNo); + if (operandCount != destBB->getNumArguments()) + return op->emitError() << "branch has " << operandCount + << " operands for successor #" << succNo + << ", but target block has " + << destBB->getNumArguments(); + + // Check the types. + auto operandIt = operands->begin(); + for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { + if ((*operandIt).getType() != destBB->getArgument(i).getType()) + return op->emitError() << "type mismatch for bb argument #" << i + << " of successor #" << succNo; + } + return success(); +} diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -4,7 +4,7 @@ ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR ) -add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport) +add_dependencies(MLIRLLVMIR MLIRControlFlowInterfacesIncGen MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport) target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport LLVMFrontendOpenMP MLIROpenMP MLIRIR) add_mlir_dialect_library(MLIRNVVMIR diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -154,6 +154,28 @@ } //===----------------------------------------------------------------------===// +// LLVM::BrOp +//===----------------------------------------------------------------------===// + +Optional BrOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool BrOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// +// LLVM::CondBrOp +//===----------------------------------------------------------------------===// + +Optional CondBrOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == 0 ? trueDestOperands() : falseDestOperands(); +} + +bool CondBrOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// // Printing/parsing for LLVM::LoadOp. //===----------------------------------------------------------------------===// @@ -229,9 +251,16 @@ } ///===----------------------------------------------------------------------===// -/// Verifying/Printing/Parsing for LLVM::InvokeOp. +/// LLVM::InvokeOp ///===----------------------------------------------------------------------===// +Optional InvokeOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == 0 ? normalDestOperands() : unwindDestOperands(); +} + +bool InvokeOp::canEraseSuccessorOperand() { return true; } + static LogicalResult verify(InvokeOp op) { if (op.getNumResults() > 1) return op.emitOpError("must have 0 or 1 result"); diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -16,6 +16,7 @@ ) add_dependencies(MLIRSPIRV + MLIRControlFlowInterfacesIncGen MLIRSPIRVAvailabilityIncGen MLIRSPIRVCanonicalizationIncGen MLIRSPIRVEnumAvailabilityIncGen diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -943,9 +943,29 @@ } //===----------------------------------------------------------------------===// +// spv.BranchOp +//===----------------------------------------------------------------------===// + +Optional spirv::BranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool spirv::BranchOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// +Optional +spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { + assert(index < 2 && "invalid successor index"); + return index == kTrueIndex ? getTrueBlockArguments() + : getFalseBlockArguments(); +} + +bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; } + static ParseResult parseBranchConditionalOp(OpAsmParser &parser, OperationState &state) { auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -9,6 +9,7 @@ add_dependencies(MLIRStandardOps MLIRCallOpInterfacesIncGen + MLIRControlFlowInterfacesIncGen MLIREDSC MLIRIR MLIRStandardOpsIncGen diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -482,7 +482,7 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); } void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); + getOperation()->eraseOperand(index); } void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -490,6 +490,13 @@ results.insert(context); } +Optional BranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool BranchOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// @@ -749,6 +756,13 @@ results.insert(context); } +Optional CondBranchOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return index == trueIndex ? getTrueOperands() : getFalseOperands(); +} + +bool CondBranchOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -950,37 +950,13 @@ return success(); } -static LogicalResult verifySuccessor(Operation *op, unsigned succNo) { - Operation::operand_range operands = op->getSuccessorOperands(succNo); - unsigned operandCount = op->getNumSuccessorOperands(succNo); - Block *destBB = op->getSuccessor(succNo); - if (operandCount != destBB->getNumArguments()) - return op->emitError() << "branch has " << operandCount - << " operands for successor #" << succNo - << ", but target block has " - << destBB->getNumArguments(); - - auto operandIt = operands.begin(); - for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) { - if ((*operandIt).getType() != destBB->getArgument(i).getType()) - return op->emitError() << "type mismatch for bb argument #" << i - << " of successor #" << succNo; - } - - return success(); -} - static LogicalResult verifyTerminatorSuccessors(Operation *op) { auto *parent = op->getParentRegion(); // Verify that the operands lines up with the BB arguments in the successor. - for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { - auto *succ = op->getSuccessor(i); + for (Block *succ : op->getSuccessors()) if (succ->getParent() != parent) return op->emitError("reference to block defined in another region"); - if (failed(verifySuccessor(op, i))) - return failure(); - } return success(); } diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -183,6 +183,13 @@ OperandRange::OperandRange(Operation *op) : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {} +/// Return the operand index of the first element of this range. The range +/// must not be empty. +unsigned OperandRange::getBeginOperandIndex() const { + assert(!empty() && "range must not be empty"); + return base->getOperandNumber(); +} + //===----------------------------------------------------------------------===// // ResultRange diff --git a/mlir/test/lib/TestDialect/CMakeLists.txt b/mlir/test/lib/TestDialect/CMakeLists.txt --- a/mlir/test/lib/TestDialect/CMakeLists.txt +++ b/mlir/test/lib/TestDialect/CMakeLists.txt @@ -16,6 +16,7 @@ TestPatterns.cpp ) add_dependencies(MLIRTestDialect + MLIRControlFlowInterfacesIncGen MLIRTestOpsIncGen MLIRTypeInferOpInterfaceIncGen ) diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h --- a/mlir/test/lib/TestDialect/TestDialect.h +++ b/mlir/test/lib/TestDialect/TestDialect.h @@ -15,6 +15,7 @@ #define MLIR_TESTDIALECT_H #include "mlir/Analysis/CallInterfaces.h" +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/Analysis/InferTypeOpInterface.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -164,6 +164,17 @@ } //===----------------------------------------------------------------------===// +// TestBranchOp +//===----------------------------------------------------------------------===// + +Optional TestBranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return getOperands(); +} + +bool TestBranchOp::canEraseSuccessorOperand() { return true; } + +//===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -11,6 +11,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/Analysis/ControlFlowInterfaces.td" include "mlir/Analysis/CallInterfaces.td" include "mlir/Analysis/InferTypeOpInterface.td" @@ -446,6 +447,11 @@ ]> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); + + let extraClassDeclaration = [{ + LogicalResult reifyReturnTypeShapes(OpBuilder &builder, + SmallVectorImpl &shapes); + }]; } def IsNotScalar : Constraint>; @@ -454,7 +460,8 @@ (I32ElementsAttrOp ConstantAttr), [(IsNotScalar $attr)]>; -def TestBranchOp : TEST_Op<"br", [Terminator]> { +def TestBranchOp : TEST_Op<"br", + [DeclareOpInterfaceMethods, Terminator]> { let successors = (successor AnySuccessor:$target); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1155,8 +1155,8 @@ continue; auto interface = opTrait->getOpInterface(); for (auto method : interface.getMethods()) { - // Don't declare if the method has a body. - if (method.getBody()) + // Don't declare if the method has a body or a default implementation. + if (method.getBody() || method.getDefaultImplementation()) continue; std::string args; llvm::raw_string_ostream os(args);