diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -585,7 +585,6 @@ llvm::Optional> getSuccessorOperands( llvm::ArrayRef operands, unsigned cond); - using BranchOpInterfaceTrait::getSuccessorOperands; // Helper function to deal with Optional operand forms void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) { diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -997,26 +997,14 @@ return "target_operand_offsets"; } -template +template static A getSubOperands(unsigned pos, A allArgs, - mlir::DenseIntElementsAttr ranges, - AdditionalArgs &&... additionalArgs) { + mlir::DenseIntElementsAttr ranges) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += (*(ranges.begin() + i)).getZExtValue(); - return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), - std::forward(additionalArgs)...); -} - -static mlir::MutableOperandRange -getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, - StringRef offsetAttr) { - Operation *owner = operands.getOwner(); - NamedAttribute targetOffsetAttr = - *owner->getMutableAttrDict().getNamed(offsetAttr); - return getSubOperands( - pos, operands, targetOffsetAttr.second.cast(), - mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); + unsigned end = start + (*(ranges.begin() + pos)).getZExtValue(); + return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)}; } static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { @@ -1032,10 +1020,10 @@ return {}; } -llvm::Optional -fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, targetArgsMutable(), - getTargetOffsetAttr()); +llvm::Optional +fir::SelectOp::getSuccessorOperands(unsigned oper) { + auto a = getAttrOfType(getTargetOffsetAttr()); + return {getSubOperands(oper, targetArgs(), a)}; } llvm::Optional> @@ -1047,6 +1035,8 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +bool fir::SelectOp::canEraseSuccessorOperand() { return true; } + unsigned fir::SelectOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); @@ -1071,10 +1061,10 @@ return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } -llvm::Optional -fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, targetArgsMutable(), - getTargetOffsetAttr()); +llvm::Optional +fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { + auto a = getAttrOfType(getTargetOffsetAttr()); + return {getSubOperands(oper, targetArgs(), a)}; } llvm::Optional> @@ -1086,6 +1076,8 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; } + // parser for fir.select_case Op static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, mlir::OperationState &result) { @@ -1262,10 +1254,10 @@ return {}; } -llvm::Optional -fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, targetArgsMutable(), - getTargetOffsetAttr()); +llvm::Optional +fir::SelectRankOp::getSuccessorOperands(unsigned oper) { + auto a = getAttrOfType(getTargetOffsetAttr()); + return {getSubOperands(oper, targetArgs(), a)}; } llvm::Optional> @@ -1277,6 +1269,8 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; } + unsigned fir::SelectRankOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); @@ -1296,10 +1290,10 @@ return {}; } -llvm::Optional -fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { - return ::getMutableSuccessorOperands(oper, targetArgsMutable(), - getTargetOffsetAttr()); +llvm::Optional +fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { + auto a = getAttrOfType(getTargetOffsetAttr()); + return {getSubOperands(oper, targetArgs(), a)}; } llvm::Optional> @@ -1311,6 +1305,8 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } +bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; } + static ParseResult parseSelectType(OpAsmParser &parser, OperationState &result) { mlir::OpAsmParser::OperandType selector; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1072,7 +1072,7 @@ /// Erase the operand at 'index' from the true operand list. void eraseTrueOperand(unsigned index) { - trueDestOperandsMutable().erase(index); + eraseSuccessorOperand(trueIndex, index); } // Accessors for operands to the 'false' destination. @@ -1091,7 +1091,7 @@ /// Erase the operand at 'index' from the false operand list. void eraseFalseOperand(unsigned index) { - falseDestOperandsMutable().erase(index); + eraseSuccessorOperand(falseIndex, index); } private: diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -678,10 +678,6 @@ ArrayRef operandSegments = llvm::None); MutableOperandRange(Operation *owner); - /// Slice this range into a sub range, with the additional operand segment. - MutableOperandRange slice(unsigned subStart, unsigned subLen, - Optional segment = llvm::None); - /// Append the given values to the range. void append(ValueRange values); @@ -703,9 +699,6 @@ /// Allow implicit conversion to an OperandRange. operator OperandRange() const; - /// Returns the owning operation. - Operation *getOwner() const { return owner; } - private: /// Update the length of this range to the one provided. void updateLength(unsigned newLength); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -24,6 +24,11 @@ //===----------------------------------------------------------------------===// namespace detail { +/// Erase an operand from a branch operation that is used as a successor +/// operand. `operandIndex` is the operand within `operands` to be erased. +void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex, + Operation *op); + /// Return the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if `operandIndex` is within the range of `operands`, or None if /// `operandIndex` isn't a successor operand index. diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -27,25 +27,29 @@ }]; let methods = [ InterfaceMethod<[{ - Returns a mutable range of operands that correspond to the arguments of + Returns a set of values that correspond to the arguments to the successor at the given index. Returns None if the operands to the successor are non-materialized values, i.e. they are internal to the operation. }], - "Optional", "getMutableSuccessorOperands", - (ins "unsigned":$index) + "Optional", "getSuccessorOperands", (ins "unsigned":$index) >, InterfaceMethod<[{ - Returns a range of operands that correspond to the arguments of - successor at the given index. Returns None if the operands to the - successor are non-materialized values, i.e. they are internal to the - operation. + Return true if this operation can erase an operand to a successor block. + }], + "bool", "canEraseSuccessorOperand" + >, + InterfaceMethod<[{ + Erase the operand at `operandIndex` from the `index`-th successor. This + should only be called if `canEraseSuccessorOperand` returns true. }], - "Optional", "getSuccessorOperands", - (ins "unsigned":$index), [{}], [{ + "void", "eraseSuccessorOperand", + (ins "unsigned":$index, "unsigned":$operandIndex), [{}], + /*defaultImplementation=*/[{ ConcreteOp *op = static_cast(this); - auto operands = op->getMutableSuccessorOperands(index); - return operands ? Optional(*operands) : llvm::None; + Optional operands = op->getSuccessorOperands(index); + assert(operands && "unable to query operands for successor"); + detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op); }] >, InterfaceMethod<[{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -160,22 +160,24 @@ // LLVM::BrOp //===----------------------------------------------------------------------===// -Optional -BrOp::getMutableSuccessorOperands(unsigned index) { +Optional BrOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return destOperandsMutable(); + return getOperands(); } +bool BrOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // LLVM::CondBrOp //===----------------------------------------------------------------------===// -Optional -CondBrOp::getMutableSuccessorOperands(unsigned index) { +Optional CondBrOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable(); + return index == 0 ? trueDestOperands() : falseDestOperands(); } +bool CondBrOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::LoadOp. //===----------------------------------------------------------------------===// @@ -255,12 +257,13 @@ /// LLVM::InvokeOp ///===---------------------------------------------------------------------===// -Optional -InvokeOp::getMutableSuccessorOperands(unsigned index) { +Optional InvokeOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable(); + return index == 0 ? normalDestOperands() : unwindDestOperands(); } +bool InvokeOp::canEraseSuccessorOperand() { return true; } + static LogicalResult verify(InvokeOp op) { if (op.getNumResults() > 1) return op.emitOpError("must have 0 or 1 result"); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -987,23 +987,26 @@ // spv.BranchOp //===----------------------------------------------------------------------===// -Optional -spirv::BranchOp::getMutableSuccessorOperands(unsigned index) { +Optional spirv::BranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return targetOperandsMutable(); + return getOperands(); } +bool spirv::BranchOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// -Optional -spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) { +Optional +spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { assert(index < 2 && "invalid successor index"); - return index == kTrueIndex ? trueTargetOperandsMutable() - : falseTargetOperandsMutable(); + return index == kTrueIndex ? getTrueBlockArguments() + : getFalseBlockArguments(); } +bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; } + static ParseResult parseBranchConditionalOp(OpAsmParser &parser, OperationState &state) { auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -677,12 +677,13 @@ context); } -Optional -BranchOp::getMutableSuccessorOperands(unsigned index) { +Optional BranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return destOperandsMutable(); + return getOperands(); } +bool BranchOp::canEraseSuccessorOperand() { return true; } + Block *BranchOp::getSuccessorForOperands(ArrayRef) { return dest(); } //===----------------------------------------------------------------------===// @@ -1020,13 +1021,13 @@ SimplifyCondBranchIdenticalSuccessors>(context); } -Optional -CondBranchOp::getMutableSuccessorOperands(unsigned index) { +Optional CondBranchOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == trueIndex ? trueDestOperandsMutable() - : falseDestOperandsMutable(); + return index == trueIndex ? getTrueOperands() : getFalseOperands(); } +bool CondBranchOp::canEraseSuccessorOperand() { return true; } + Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { if (BoolAttr condAttr = operands.front().dyn_cast_or_null()) return condAttr.getValue() ? trueDest() : falseDest(); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -287,18 +287,6 @@ MutableOperandRange::MutableOperandRange(Operation *owner) : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} -/// Slice this range into a sub range, with the additional operand segment. -MutableOperandRange -MutableOperandRange::slice(unsigned subStart, unsigned subLen, - Optional segment) { - assert((subStart + subLen) <= length && "invalid sub-range"); - MutableOperandRange subSlice(owner, start + subStart, subLen, - operandSegments); - if (segment) - subSlice.operandSegments.push_back(*segment); - return subSlice; -} - /// Append the given values to the range. void MutableOperandRange::append(ValueRange values) { if (values.empty()) diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -21,6 +21,39 @@ // BranchOpInterface //===----------------------------------------------------------------------===// +/// Erase an operand from a branch operation that is used as a successor +/// operand. 'operandIndex' is the operand within 'operands' to be erased. +void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands, + unsigned operandIndex, + Operation *op) { + assert(operandIndex < operands.size() && + "invalid index for successor operands"); + + // Erase the operand from the operation. + size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex; + op->eraseOperand(fullOperandIndex); + + // If this operation has an OperandSegmentSizeAttr, keep it up to date. + auto operandSegmentAttr = + op->getAttrOfType("operand_segment_sizes"); + if (!operandSegmentAttr) + return; + + // Find the segment containing the full operand index and decrement it. + // TODO: This seems like a general utility that could be added somewhere. + SmallVector values(operandSegmentAttr.getValues()); + unsigned currentSize = 0; + for (unsigned i = 0, e = values.size(); i != e; ++i) { + currentSize += values[i]; + if (fullOperandIndex < currentSize) { + --values[i]; + break; + } + } + op->setAttr("operand_segment_sizes", + DenseIntElementsAttr::get(operandSegmentAttr.getType(), values)); +} + /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if 'operandIndex' is within the range of 'operands', or None if /// `operandIndex` isn't a successor operand index. diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -209,7 +209,7 @@ // Check to see if we can reason about the successor operands and mutate them. BranchOpInterface branchInterface = dyn_cast(op); - if (!branchInterface) { + if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) { for (Block *successor : op->getSuccessors()) for (BlockArgument arg : successor->getArguments()) liveMap.setProvedLive(arg); @@ -219,7 +219,7 @@ // If we can't reason about the operands to a successor, conservatively mark // all arguments as live. for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { - if (!branchInterface.getMutableSuccessorOperands(i)) + if (!branchInterface.getSuccessorOperands(i)) for (BlockArgument arg : op->getSuccessor(i)->getArguments()) liveMap.setProvedLive(arg); } @@ -278,8 +278,7 @@ // since it will promote later operands of the terminator being erased // first, reducing the quadratic-ness. unsigned succ = succE - succI - 1; - Optional succOperands = - branchOp.getMutableSuccessorOperands(succ); + Optional succOperands = branchOp.getSuccessorOperands(succ); if (!succOperands) continue; Block *successor = terminator->getSuccessor(succ); @@ -289,7 +288,7 @@ // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; if (!liveMap.wasProvenLive(successor->getArgument(arg))) - succOperands->erase(arg); + branchOp.eraseSuccessorOperand(succ, arg); } } } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -167,12 +167,13 @@ // TestBranchOp //===----------------------------------------------------------------------===// -Optional -TestBranchOp::getMutableSuccessorOperands(unsigned index) { +Optional TestBranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return targetOperandsMutable(); + return getOperands(); } +bool TestBranchOp::canEraseSuccessorOperand() { return true; } + //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -146,7 +146,7 @@ StringRef interfaceName, StringRef interfaceTraitsName) { os << " template \n " - << llvm::formatv("struct {0}Trait : public OpInterface<{0}," + << llvm::formatv("struct Trait : public OpInterface<{0}," " detail::{1}>::Trait {{\n", interfaceName, interfaceTraitsName); @@ -171,17 +171,13 @@ tblgen::FmtContext traitCtx; traitCtx.withOp("op"); if (auto verify = interface.getVerify()) { - os << " static LogicalResult verifyTrait(Operation* op) {\n" + os << " static LogicalResult verifyTrait(Operation* op) {\n" << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n"; } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) os << extraTraitDecls << "\n"; os << " };\n"; - - // Emit a utility using directive for the trait class. - os << " template \n " - << llvm::formatv("using Trait = {0}Trait;\n", interfaceName); } static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {