diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -489,16 +489,12 @@ llvm::ArrayRef operands, unsigned cond); llvm::Optional getSuccessorOperands( mlir::ValueRange operands, unsigned cond); - using BranchOpInterfaceTrait::getSuccessorOperands; - // Helper function to deal with Optional operand forms + // Helper function to print successor `i` together with its forwarded operands. void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) { auto *succ = getSuccessor(i); auto ops = getSuccessorOperands(i); - if (ops.hasValue()) - p.printSuccessorAndUseList(succ, ops.getValue()); - else - p.printSuccessor(succ); + p.printSuccessorAndUseList(succ, ops.getForwardedOperands()); } mlir::ArrayAttr getCases() { diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2401,10 +2401,9 @@ return {}; } -llvm::Optional -fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(), - getTargetOffsetAttr()); +mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { + return mlir::SuccessorOperands(::getMutableSuccessorOperands( + oper, getTargetArgsMutable(), getTargetOffsetAttr())); } llvm::Optional> @@ -2462,10 +2461,9 @@ return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } -llvm::Optional -fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(), - getTargetOffsetAttr()); +mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { + return mlir::SuccessorOperands(::getMutableSuccessorOperands( + oper, getTargetArgsMutable(), getTargetOffsetAttr())); } llvm::Optional> @@ -2734,10 +2732,9 @@ return {}; } -llvm::Optional
-fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(), - getTargetOffsetAttr()); +mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) { + return mlir::SuccessorOperands(::getMutableSuccessorOperands( + oper, getTargetArgsMutable(), getTargetOffsetAttr())); } llvm::Optional> @@ -2779,10 +2776,9 @@ return {}; } -llvm::Optional -fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(), - getTargetOffsetAttr()); +mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { + return mlir::SuccessorOperands(::getMutableSuccessorOperands( + oper, getTargetArgsMutable(), getTargetOffsetAttr())); } llvm::Optional> diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -907,6 +907,11 @@ /// elements attribute, which contains the sizes of the sub ranges. MutableOperandRangeRange split(NamedAttribute segmentSizes) const; + /// Returns the value at the given index. + Value operator[](unsigned index) const { + return static_cast(*this)[index]; + } + private: /// Update the length of this range to the one provided. void updateLength(unsigned newLength); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -20,6 +20,106 @@ class BranchOpInterface; class RegionBranchOpInterface; +/// This class models how operands are forwarded to block arguments in control +/// flow. It consists of a number, denoting how many of the successor's block +/// arguments are produced by the operation, followed by a range of operands +/// that are forwarded.
The produced operands are passed to the first few +/// block arguments of the successor, followed by the forwarded operands. +/// It is unsupported to pass them in a different order. +/// +/// An example operation with both of these concepts would be a branch-on-error +/// operation that internally produces an error object on the error path: +/// +/// invoke %function(%0) +/// label ^success ^error(%1 : i32) +/// +/// ^error(%e: !error, %arg0 : i32): +/// ... +/// +/// This operation would return an instance of SuccessorOperands with a produced +/// operand count of 1 (mapped to %e in the successor) and a forwarded +/// operands range consisting of %1 in the example above (mapped to %arg0 in the +/// successor). +class SuccessorOperands { +public: + /// Constructs a SuccessorOperands with no produced operands that simply + /// forwards operands to the successor. + explicit SuccessorOperands(MutableOperandRange forwardedOperands); + + /// Constructs a SuccessorOperands with the given number of produced operands + /// and forwarded operands. + SuccessorOperands(unsigned producedOperandCount, + MutableOperandRange forwardedOperands); + + /// Returns the number of operands passed to the successor. This consists of + /// both the operands produced by the operation and the forwarded ones. + unsigned size() const { + return producedOperandCount + forwardedOperands.size(); + } + + /// Returns true if there are no successor operands. + bool empty() const { return size() == 0; } + + /// Returns the number of operands that are produced internally by the + /// operation. These are passed to the first few block arguments. + unsigned getProducedOperandCount() const { return producedOperandCount; } + + /// Returns true if the successor operand denoted by `index` is produced by + /// the operation.
+ bool isOperandProduced(unsigned index) const { + return index < producedOperandCount; + } + + /// Returns the Value that is passed to the successor's block argument denoted + /// by `index`. If it is produced by the operation, no such value exists and + /// a null Value is returned. + Value operator[](unsigned index) const { + if (isOperandProduced(index)) + return Value(); + return forwardedOperands[index - producedOperandCount]; + } + + /// Get the range of operands that are simply forwarded to the successor. + OperandRange getForwardedOperands() const { return forwardedOperands; } + + /// Get a slice of the operands forwarded to the successor. The given range + /// must not contain any operands produced by the operation. + MutableOperandRange slice(unsigned subStart, unsigned subLen) const { + assert(!isOperandProduced(subStart) && + "can't slice operands produced by the operation"); + return forwardedOperands.slice(subStart - producedOperandCount, subLen); + } + + /// Erase operands forwarded to the successor. The given range must + /// not contain any operands produced by the operation. + void erase(unsigned subStart, unsigned subLen = 1) { + assert(!isOperandProduced(subStart) && + "can't erase operands produced by the operation"); + forwardedOperands.erase(subStart - producedOperandCount, subLen); + } + + /// Add new operands that are forwarded to the successor. + void append(ValueRange valueRange) { forwardedOperands.append(valueRange); } + + /// Gets the index of the forwarded operand within the operation which maps + /// to the block argument denoted by `blockArgumentIndex`. The block argument + /// must be mapped to a forwarded operand.
+ unsigned getOperandIndex(unsigned blockArgumentIndex) const { + assert(!isOperandProduced(blockArgumentIndex) && + "can't map operand produced by the operation"); + return static_cast(forwardedOperands) + .getBeginOperandIndex() + + (blockArgumentIndex - producedOperandCount); + } + +private: + /// Number of operands that are produced internally within the operation and + /// passed to the first few block arguments. + unsigned producedOperandCount; + /// Range of operands that are forwarded to the remaining block arguments. + MutableOperandRange forwardedOperands; +}; + //===----------------------------------------------------------------------===// // BranchOpInterface //===----------------------------------------------------------------------===// @@ -29,12 +129,12 @@ /// successor if `operandIndex` is within the range of `operands`, or None if /// `operandIndex` isn't a successor operand index. Optional -getBranchSuccessorArgument(Optional operands, +getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor); /// Verify that the given operands match those of the given successor block. LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, - Optional operands); + const SuccessorOperands &operands); } // namespace detail //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -36,26 +36,35 @@ let methods = [ InterfaceMethod<[{ - Returns a mutable range of operands that correspond to the arguments of - successor at the given index. Returns None if the operands to the - successor are non-materialized values, i.e. they are internal to the - operation.
+ Returns the operands that correspond to the arguments of the successor + at the given index. It consists of a count of operands that are + internally produced by the operation, followed by a range of operands + that are forwarded. An example operation making use of produced + operands would be: + + ```mlir + invoke %function(%0) + label ^success ^error(%1 : i32) + + ^error(%e: !error, %arg0: i32): + ... + ``` + + The operand that would map to `^error`'s `%e` argument is produced + by the `invoke` operation, while `%1` is a forwarded operand that maps + to `%arg0` in the successor. + + Produced operands always map to the first few block arguments of the + successor, followed by the forwarded operands. Mapping them in any + other order is not supported by the interface. + + Having the forwarded operands last allows users of the interface + to append more forwarded operands to the branch operation without + interfering with other successor operands. }], - "::mlir::Optional<::mlir::MutableOperandRange>", "getMutableSuccessorOperands", + "::mlir::SuccessorOperands", "getSuccessorOperands", (ins "unsigned":$index) >, - InterfaceMethod<[{ - Returns a range of operands that correspond to the arguments of - successor at the given index. Returns None if the operands to the - successor are non-materialized values, i.e. they are internal to the - operation. - }], - "::mlir::Optional<::mlir::OperandRange>", "getSuccessorOperands", - (ins "unsigned":$index), [{}], [{ - auto operands = $_op.getMutableSuccessorOperands(index); - return operands ?
::mlir::Optional<::mlir::OperandRange>(*operands) : ::llvm::None; - }] - >, InterfaceMethod<[{ Returns the `BlockArgument` corresponding to operand `operandIndex` in some successor, or None if `operandIndex` isn't a successor operand @@ -94,7 +103,7 @@ let verify = [{ auto concreteOp = ::mlir::cast($_op); for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) { - ::mlir::Optional operands = concreteOp.getSuccessorOperands(i); + ::mlir::SuccessorOperands operands = concreteOp.getSuccessorOperands(i); if (::mlir::failed(::mlir::detail::verifyBranchSuccessorOperands($_op, i, operands))) return ::mlir::failure(); } diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -149,14 +149,13 @@ // Try to get the operand passed for this argument. unsigned index = it.getSuccessorIndex(); - Optional operands = branch.getSuccessorOperands(index); - if (!operands) { + Value operand = branch.getSuccessorOperands(index)[argNumber]; + if (!operand) { // We can't analyze the control flow, so bail out early. output.push_back(arg); return; } - collectUnderlyingAddressValues((*operands)[argNumber], maxDepth, visited, - output); + collectUnderlyingAddressValues(operand, maxDepth, visited, output); } return; } diff --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp --- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp @@ -70,10 +70,10 @@ // Query the branch op interface to get the successor operands. auto successorOperands = branchInterface.getSuccessorOperands(it.getIndex()); - if (!successorOperands.hasValue()) - continue; // Build the actual mapping of values to their immediate dependencies.
- registerDependencies(successorOperands.getValue(), (*it)->getArguments()); + registerDependencies(successorOperands.getForwardedOperands(), + (*it)->getArguments().drop_front( + successorOperands.getProducedOperandCount())); } }); diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -681,10 +681,13 @@ // Try to get the operand forwarded by the predecessor. If we can't reason // about the terminator of the predecessor, mark as having reached a // fixpoint. - Optional branchOperands; - if (auto branch = dyn_cast(pred->getTerminator())) - branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); - if (!branchOperands) { + auto branch = dyn_cast(pred->getTerminator()); + if (!branch) { + updatedLattice |= argLattice.markPessimisticFixpoint(); + break; + } + Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i]; + if (!operand) { updatedLattice |= argLattice.markPessimisticFixpoint(); break; } @@ -692,7 +695,7 @@ // If the operand hasn't been resolved, it is uninitialized and can merge // with anything. AbstractLatticeElement *operandLattice = - analysis.lookupLatticeElement((*branchOperands)[i]); + analysis.lookupLatticeElement(operand); if (!operandLattice) continue; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -325,25 +325,24 @@ // argument. Operation *terminator = (*it)->getTerminator(); auto branchInterface = cast(terminator); + SuccessorOperands operands = + branchInterface.getSuccessorOperands(it.getSuccessorIndex()); + // Query the associated source value.
- Value sourceValue = - branchInterface.getSuccessorOperands(it.getSuccessorIndex()) - .getValue()[blockArg.getArgNumber()]; - // Wire new clone and successor operand. - auto mutableOperands = - branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex()); - if (!mutableOperands) { - terminator->emitError() << "terminators with immutable successor " - "operands are not supported"; - continue; + Value sourceValue = operands[blockArg.getArgNumber()]; + // Operands produced by the terminator itself have no SSA value that a + // clone could be wired to, so diagnose them instead of failing silently. + if (!sourceValue) { + return terminator->emitError() + << "terminators with produced successor operands are not " + "supported"; } + // Wire new clone and successor operand. // Create a new clone at the current location of the terminator. auto clone = introduceCloneBuffers(sourceValue, terminator); if (failed(clone)) return failure(); - mutableOperands.getValue() - .slice(blockArg.getArgNumber(), 1) - .assign(*clone); + operands.slice(blockArg.getArgNumber(), 1).assign(*clone); } // Check whether the block argument has implicitly defined predecessors via diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -186,10 +186,9 @@ void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } -Optional -BranchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getDestOperandsMutable(); + return SuccessorOperands(getDestOperandsMutable()); } Block *BranchOp::getSuccessorForOperands(ArrayRef) { @@ -437,11 +436,10 @@ CondBranchTruthPropagation>(context); } -Optional -CondBranchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == trueIndex ?
getTrueDestOperandsMutable() - : getFalseDestOperandsMutable(); + return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() + : getFalseDestOperandsMutable()); } Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { @@ -575,11 +573,10 @@ return success(); } -Optional -SwitchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? getDefaultOperandsMutable() - : getCaseOperandsMutable(index - 1); + return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() + : getCaseOperandsMutable(index - 1)); } Block *SwitchOp::getSuccessorForOperands(ArrayRef operands) { diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -67,12 +67,13 @@ SmallVector newOperands(op->operand_begin(), op->operand_end()); for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); succIdx < succEnd; ++succIdx) { - auto successorOperands = op.getSuccessorOperands(succIdx); - if (!successorOperands || successorOperands->empty()) + OperandRange forwardedOperands = + op.getSuccessorOperands(succIdx).getForwardedOperands(); + if (forwardedOperands.empty()) continue; - for (int idx = successorOperands->getBeginOperandIndex(), - eidx = idx + successorOperands->size(); + for (int idx = forwardedOperands.getBeginOperandIndex(), + eidx = idx + forwardedOperands.size(); idx < eidx; ++idx) { if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) newOperands[idx] = operands[idx]; @@ -121,8 +122,8 @@ if (auto branchOp = dyn_cast(op)) { for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { auto successorOperands = branchOp.getSuccessorOperands(p); - if (successorOperands.hasValue() && -
!converter.isLegal(successorOperands.getValue().getTypes())) + if (!converter.isLegal( + successorOperands.getForwardedOperands().getTypes())) return false; } return true; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -240,21 +240,19 @@ // LLVM::BrOp //===----------------------------------------------------------------------===// -Optional -BrOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getDestOperandsMutable(); + return SuccessorOperands(getDestOperandsMutable()); } //===----------------------------------------------------------------------===// // LLVM::CondBrOp //===----------------------------------------------------------------------===// -Optional -CondBrOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? getTrueDestOperandsMutable() - : getFalseDestOperandsMutable(); + return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() + : getFalseDestOperandsMutable()); } //===----------------------------------------------------------------------===// @@ -356,11 +354,10 @@ return success(); } -Optional -SwitchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ?
getDefaultOperandsMutable() - : getCaseOperandsMutable(index - 1); + return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() + : getCaseOperandsMutable(index - 1)); } //===----------------------------------------------------------------------===// @@ -735,11 +732,10 @@ /// LLVM::InvokeOp ///===---------------------------------------------------------------------===// -Optional -InvokeOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? getNormalDestOperandsMutable() - : getUnwindDestOperandsMutable(); + return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() + : getUnwindDestOperandsMutable()); } LogicalResult InvokeOp::verify() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -223,12 +223,12 @@ auto blockOperands = terminator.getSuccessorOperands(pred.getSuccessorIndex()); - if (!blockOperands || blockOperands->empty()) + if (blockOperands.empty() || + blockOperands.isOperandProduced(blockArgumentElem.getArgNumber())) continue; detensorableBranchOps[terminator].insert( - blockOperands->getBeginOperandIndex() + - blockArgumentElem.getArgNumber()); + blockOperands.getOperandIndex(blockArgumentElem.getArgNumber())); } } @@ -343,14 +343,15 @@ auto ownerBlockOperands = predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); - if (!ownerBlockOperands || ownerBlockOperands->empty()) + if (ownerBlockOperands.empty() || + ownerBlockOperands.isOperandProduced( + currentItemBlockArgument.getArgNumber())) continue; // For each predecessor, add the value it passes to that argument to // workList to find out how it's computed.
workList.push_back( - ownerBlockOperands - .getValue()[currentItemBlockArgument.getArgNumber()]); + ownerBlockOperands[currentItemBlockArgument.getArgNumber()]); } continue; @@ -418,18 +419,16 @@ auto blockOperands = terminator.getSuccessorOperands(pred.getSuccessorIndex()); - if (!blockOperands || blockOperands->empty()) + if (blockOperands.empty() || + blockOperands.isOperandProduced(blockArg.getArgNumber())) continue; Operation *definingOp = - terminator - ->getOperand(blockOperands->getBeginOperandIndex() + - blockArg.getArgNumber()) - .getDefiningOp(); + blockOperands[blockArg.getArgNumber()].getDefiningOp(); // If the operand is defined by a GenericOp that will not be // detensored, then do not detensor the corresponding block argument. - if (dyn_cast_or_null(definingOp) && + if (isa_and_nonnull(definingOp) && opsToDetensor.count(definingOp) == 0) { blockArgsToRemove.insert(blockArg); break; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1515,21 +1515,20 @@ // spv.BranchOp //===----------------------------------------------------------------------===// -Optional -spirv::BranchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return targetOperandsMutable(); + return SuccessorOperands(targetOperandsMutable()); } //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// -Optional -spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands +spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { assert(index < 2 && "invalid successor index"); - return index == kTrueIndex ?
trueTargetOperandsMutable() - : falseTargetOperandsMutable(); + return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable() + : falseTargetOperandsMutable()); } ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -18,6 +18,14 @@ #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" +SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) + : SuccessorOperands(0, std::move(forwardedOperands)) {} + +SuccessorOperands::SuccessorOperands(unsigned producedOperandCount, + MutableOperandRange forwardedOperands) + : producedOperandCount(producedOperandCount), + forwardedOperands(std::move(forwardedOperands)) {} + //===----------------------------------------------------------------------===// // BranchOpInterface //===----------------------------------------------------------------------===// @@ -26,32 +34,31 @@ /// successor if 'operandIndex' is within the range of 'operands', or None if /// `operandIndex` isn't a successor operand index. Optional -detail::getBranchSuccessorArgument(Optional operands, +detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { + OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. - if (!operands || operands->empty()) + if (forwardedOperands.empty()) return llvm::None; // Check to ensure that this operand is within the range. - unsigned operandsStart = operands->getBeginOperandIndex(); + unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || - operandIndex >= (operandsStart + operands->size())) + operandIndex >= (operandsStart + forwardedOperands.size())) return llvm::None; // Index the successor.
- unsigned argIndex = operandIndex - operandsStart; + unsigned argIndex = + operands.getProducedOperandCount() + operandIndex - operandsStart; return successor->getArgument(argIndex); } /// Verify that the given operands match those of the given successor block. LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, - Optional operands) { - if (!operands) - return success(); - + const SuccessorOperands &operands) { // Check the count. - unsigned operandCount = operands->size(); + unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount @@ -60,10 +67,10 @@ << destBB->getNumArguments(); // Check the types. - auto operandIt = operands->begin(); - for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { + for (unsigned i = operands.getProducedOperandCount(); i != operandCount; + ++i) { if (!cast(op).areTypesCompatible( - (*operandIt).getType(), destBB->getArgument(i).getType())) + operands[i].getType(), destBB->getArgument(i).getType())) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -441,10 +441,9 @@ for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) { Block *successor = terminator.getSuccessor(i); auto branch = cast(terminator); - Optional successorOperands = branch.getSuccessorOperands(i); + SuccessorOperands successorOperands = branch.getSuccessorOperands(i); assert( - (!seenSuccessors.contains(successor) || - (successorOperands && successorOperands->empty())) && + (!seenSuccessors.contains(successor) || successorOperands.empty()) && "successors with arguments in LLVM branches must be different blocks");
seenSuccessors.insert(successor); } diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -223,12 +223,14 @@ return; } - // If we can't reason about the operands to a successor, conservatively mark - // all arguments as live. + // Block arguments fed by operands the terminator produces itself have no + // forwarded operand that could be erased, so conservatively mark them live. for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { - if (!branchInterface.getMutableSuccessorOperands(i)) - for (BlockArgument arg : op->getSuccessor(i)->getArguments()) - liveMap.setProvedLive(arg); + SuccessorOperands successorOperands = + branchInterface.getSuccessorOperands(i); + for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount(); + opI != opE; ++opI) + liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI)); } } @@ -291,18 +293,15 @@ // since it will promote later operands of the terminator being erased // first, reducing the quadratic-ness. unsigned succ = succE - succI - 1; - Optional succOperands = - branchOp.getMutableSuccessorOperands(succ); - if (!succOperands) - continue; + SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ); Block *successor = terminator->getSuccessor(succ); - for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) { + for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) { // Iterating args in reverse is needed for correctness, to avoid // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; if (!liveMap.wasProvenLive(successor->getArgument(arg))) - succOperands->erase(arg); + succOperands.erase(arg); } } } @@ -570,8 +569,7 @@ /// their operands updated.
static bool ableToUpdatePredOperands(Block *block) { for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { - auto branch = dyn_cast((*it)->getTerminator()); - if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex())) + if (!isa((*it)->getTerminator())) return false; } return true; @@ -631,7 +629,7 @@ predIt != predE; ++predIt) { auto branch = cast((*predIt)->getTerminator()); unsigned succIndex = predIt.getSuccessorIndex(); - branch.getMutableSuccessorOperands(succIndex)->append( + branch.getSuccessorOperands(succIndex).append( newArguments[clusterIndex]); } }; diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir --- a/mlir/test/Transforms/sccp.mlir +++ b/mlir/test/Transforms/sccp.mlir @@ -198,3 +198,21 @@ // CHECK: return %[[X]], %[[Y]] return %x, %y : i1, i1 } + +// CHECK-LABEL: func @simple_produced_operand +func @simple_produced_operand() -> (i32, i32) { + // CHECK: %[[ONE:.*]] = arith.constant 1 + %1 = arith.constant 1 : i32 + "test.internal_br"(%1) [^bb1, ^bb2] { + operand_segment_sizes = dense<[0, 1]> : vector<2 x i32> + } : (i32) -> () + +^bb1: + cf.br ^bb2(%1, %1 : i32, i32) + +^bb2(%arg1 : i32, %arg2 : i32): + // CHECK: ^bb2(%[[ARG:.*]]: i32, %{{.*}}: i32): + // CHECK: return %[[ARG]], %[[ONE]] : i32, i32 + + return %arg1, %arg2 : i32, i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -335,22 +335,31 @@ // TestBranchOp //===----------------------------------------------------------------------===// -Optional -TestBranchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getTargetOperandsMutable(); + return SuccessorOperands(getTargetOperandsMutable()); }
//===----------------------------------------------------------------------===// // TestProducingBranchOp //===----------------------------------------------------------------------===// -Optional -TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) { +SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { assert(index <= 1 && "invalid successor index"); if (index == 1) - return getFirstOperandsMutable(); - return getSecondOperandsMutable(); + return SuccessorOperands(getFirstOperandsMutable()); + return SuccessorOperands(getSecondOperandsMutable()); +} + +//===----------------------------------------------------------------------===// +// TestInternalBranchOp +//===----------------------------------------------------------------------===// + +SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { + assert(index <= 1 && "invalid successor index"); + if (index == 0) + return SuccessorOperands(0, getSuccessOperandsMutable()); + return SuccessorOperands(1, getErrorOperandsMutable()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -642,6 +642,18 @@ let successors = (successor AnySuccessor:$first,AnySuccessor:$second); } +// Terminator that produces one operand internally for the error successor, +// followed by the forwarded `errorOperands`. +def TestInternalBranchOp : TEST_Op<"internal_br", + [DeclareOpInterfaceMethods, Terminator, + AttrSizedOperandSegments]> { + + let arguments = (ins Variadic:$successOperands, + Variadic:$errorOperands); + + let successors = (successor AnySuccessor:$successPath, AnySuccessor:$errorPath); +} + def AttrSizedOperandOp : TEST_Op<"attr_sized_operands", [AttrSizedOperandSegments]> { let arguments = (ins