diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -585,6 +585,7 @@ llvm::Optional> getSuccessorOperands( llvm::ArrayRef operands, unsigned cond); + using BranchOpInterfaceTrait::getSuccessorOperands; // Helper function to deal with Optional operand forms void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) { diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -997,14 +997,26 @@ return "target_operand_offsets"; } -template +template static A getSubOperands(unsigned pos, A allArgs, - mlir::DenseIntElementsAttr ranges) { + mlir::DenseIntElementsAttr ranges, + AdditionalArgs &&... additionalArgs) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += (*(ranges.begin() + i)).getZExtValue(); - unsigned end = start + (*(ranges.begin() + pos)).getZExtValue(); - return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)}; + return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), + std::forward(additionalArgs)...); +} + +static mlir::MutableOperandRange +getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, + StringRef offsetAttr) { + Operation *owner = operands.getOwner(); + NamedAttribute targetOffsetAttr = + *owner->getMutableAttrDict().getNamed(offsetAttr); + return getSubOperands( + pos, operands, targetOffsetAttr.second.cast(), + mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); } static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { @@ -1020,10 +1032,10 @@ return {}; } -llvm::Optional -fir::SelectOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional +fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional> @@ -1035,8 +1047,6 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectOp::canEraseSuccessorOperand() { return true; } - unsigned fir::SelectOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); @@ -1061,10 +1071,10 @@ return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } -llvm::Optional -fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional +fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional> @@ -1076,8 +1086,6 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; } - // parser for fir.select_case Op static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, mlir::OperationState &result) { @@ -1254,10 +1262,10 @@ return {}; } -llvm::Optional -fir::SelectRankOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional +fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional> @@ -1269,8 +1277,6 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; } - unsigned fir::SelectRankOp::targetOffsetSize() { return denseElementsSize( getAttrOfType(getTargetOffsetAttr())); @@ -1290,10 +1296,10 @@ return {}; } -llvm::Optional -fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional +fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional> @@ -1305,8 +1311,6 @@ return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; } - static ParseResult parseSelectType(OpAsmParser &parser, OperationState &result) { mlir::OpAsmParser::OperandType selector; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1074,7 +1074,7 @@ /// Erase the operand at 'index' from the true operand list. void eraseTrueOperand(unsigned index) { - eraseSuccessorOperand(trueIndex, index); + trueDestOperandsMutable().erase(index); } // Accessors for operands to the 'false' destination. @@ -1093,7 +1093,7 @@ /// Erase the operand at 'index' from the false operand list. void eraseFalseOperand(unsigned index) { - eraseSuccessorOperand(falseIndex, index); + falseDestOperandsMutable().erase(index); } private: diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -678,6 +678,10 @@ ArrayRef operandSegments = llvm::None); MutableOperandRange(Operation *owner); + /// Slice this range into a sub range, with the additional operand segment. + MutableOperandRange slice(unsigned subStart, unsigned subLen, + Optional segment = llvm::None); + /// Append the given values to the range. void append(ValueRange values); @@ -699,6 +703,9 @@ /// Allow implicit conversion to an OperandRange. operator OperandRange() const; + /// Returns the owning operation. + Operation *getOwner() const { return owner; } + private: /// Update the length of this range to the one provided. void updateLength(unsigned newLength); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -24,11 +24,6 @@ //===----------------------------------------------------------------------===// namespace detail { -/// Erase an operand from a branch operation that is used as a successor -/// operand. `operandIndex` is the operand within `operands` to be erased. -void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex, - Operation *op); - /// Return the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if `operandIndex` is within the range of `operands`, or None if /// `operandIndex` isn't a successor operand index. diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -27,29 +27,25 @@ }]; let methods = [ InterfaceMethod<[{ - Returns a set of values that correspond to the arguments to the + Returns a mutable range of operands that correspond to the arguments of successor at the given index. Returns None if the operands to the successor are non-materialized values, i.e. they are internal to the operation. }], - "Optional", "getSuccessorOperands", (ins "unsigned":$index) + "Optional", "getMutableSuccessorOperands", + (ins "unsigned":$index) >, InterfaceMethod<[{ - Return true if this operation can erase an operand to a successor block. - }], - "bool", "canEraseSuccessorOperand" - >, - InterfaceMethod<[{ - Erase the operand at `operandIndex` from the `index`-th successor. This - should only be called if `canEraseSuccessorOperand` returns true. + Returns a range of operands that correspond to the arguments of + successor at the given index. Returns None if the operands to the + successor are non-materialized values, i.e. they are internal to the + operation. }], - "void", "eraseSuccessorOperand", - (ins "unsigned":$index, "unsigned":$operandIndex), [{}], - /*defaultImplementation=*/[{ + "Optional", "getSuccessorOperands", + (ins "unsigned":$index), [{}], [{ ConcreteOp *op = static_cast(this); - Optional operands = op->getSuccessorOperands(index); - assert(operands && "unable to query operands for successor"); - detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op); + auto operands = op->getMutableSuccessorOperands(index); + return operands ? Optional(*operands) : llvm::None; }] >, InterfaceMethod<[{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -160,24 +160,22 @@ // LLVM::BrOp //===----------------------------------------------------------------------===// -Optional BrOp::getSuccessorOperands(unsigned index) { +Optional +BrOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return destOperandsMutable(); } -bool BrOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // LLVM::CondBrOp //===----------------------------------------------------------------------===// -Optional CondBrOp::getSuccessorOperands(unsigned index) { +Optional +CondBrOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? trueDestOperands() : falseDestOperands(); + return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable(); } -bool CondBrOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::LoadOp. //===----------------------------------------------------------------------===// @@ -257,13 +255,12 @@ /// LLVM::InvokeOp ///===---------------------------------------------------------------------===// -Optional InvokeOp::getSuccessorOperands(unsigned index) { +Optional +InvokeOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? normalDestOperands() : unwindDestOperands(); + return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable(); } -bool InvokeOp::canEraseSuccessorOperand() { return true; } - static LogicalResult verify(InvokeOp op) { if (op.getNumResults() > 1) return op.emitOpError("must have 0 or 1 result"); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -987,26 +987,23 @@ // spv.BranchOp //===----------------------------------------------------------------------===// -Optional spirv::BranchOp::getSuccessorOperands(unsigned index) { +Optional +spirv::BranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return targetOperandsMutable(); } -bool spirv::BranchOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// -Optional -spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { +Optional +spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) { assert(index < 2 && "invalid successor index"); - return index == kTrueIndex ? getTrueBlockArguments() - : getFalseBlockArguments(); + return index == kTrueIndex ? trueTargetOperandsMutable() + : falseTargetOperandsMutable(); } -bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; } - static ParseResult parseBranchConditionalOp(OpAsmParser &parser, OperationState &state) { auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -677,13 +677,12 @@ context); } -Optional BranchOp::getSuccessorOperands(unsigned index) { +Optional +BranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return destOperandsMutable(); } -bool BranchOp::canEraseSuccessorOperand() { return true; } - Block *BranchOp::getSuccessorForOperands(ArrayRef) { return dest(); } //===----------------------------------------------------------------------===// @@ -1021,13 +1020,13 @@ SimplifyCondBranchIdenticalSuccessors>(context); } -Optional CondBranchOp::getSuccessorOperands(unsigned index) { +Optional +CondBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == trueIndex ? getTrueOperands() : getFalseOperands(); + return index == trueIndex ? trueDestOperandsMutable() + : falseDestOperandsMutable(); } -bool CondBranchOp::canEraseSuccessorOperand() { return true; } - Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { if (BoolAttr condAttr = operands.front().dyn_cast_or_null()) return condAttr.getValue() ? trueDest() : falseDest(); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -287,6 +287,18 @@ MutableOperandRange::MutableOperandRange(Operation *owner) : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} +/// Slice this range into a sub range, with the additional operand segment. +MutableOperandRange +MutableOperandRange::slice(unsigned subStart, unsigned subLen, + Optional segment) { + assert((subStart + subLen) <= length && "invalid sub-range"); + MutableOperandRange subSlice(owner, start + subStart, subLen, + operandSegments); + if (segment) + subSlice.operandSegments.push_back(*segment); + return subSlice; +} + /// Append the given values to the range. void MutableOperandRange::append(ValueRange values) { if (values.empty()) diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -21,39 +21,6 @@ // BranchOpInterface //===----------------------------------------------------------------------===// -/// Erase an operand from a branch operation that is used as a successor -/// operand. 'operandIndex' is the operand within 'operands' to be erased. -void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands, - unsigned operandIndex, - Operation *op) { - assert(operandIndex < operands.size() && - "invalid index for successor operands"); - - // Erase the operand from the operation. - size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex; - op->eraseOperand(fullOperandIndex); - - // If this operation has an OperandSegmentSizeAttr, keep it up to date. - auto operandSegmentAttr = - op->getAttrOfType("operand_segment_sizes"); - if (!operandSegmentAttr) - return; - - // Find the segment containing the full operand index and decrement it. - // TODO: This seems like a general utility that could be added somewhere. - SmallVector values(operandSegmentAttr.getValues()); - unsigned currentSize = 0; - for (unsigned i = 0, e = values.size(); i != e; ++i) { - currentSize += values[i]; - if (fullOperandIndex < currentSize) { - --values[i]; - break; - } - } - op->setAttr("operand_segment_sizes", - DenseIntElementsAttr::get(operandSegmentAttr.getType(), values)); -} - /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if 'operandIndex' is within the range of 'operands', or None if /// `operandIndex` isn't a successor operand index. diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -209,7 +209,7 @@ // Check to see if we can reason about the successor operands and mutate them. BranchOpInterface branchInterface = dyn_cast(op); - if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) { + if (!branchInterface) { for (Block *successor : op->getSuccessors()) for (BlockArgument arg : successor->getArguments()) liveMap.setProvedLive(arg); @@ -219,7 +219,7 @@ // If we can't reason about the operands to a successor, conservatively mark // all arguments as live. for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { - if (!branchInterface.getSuccessorOperands(i)) + if (!branchInterface.getMutableSuccessorOperands(i)) for (BlockArgument arg : op->getSuccessor(i)->getArguments()) liveMap.setProvedLive(arg); } @@ -278,7 +278,8 @@ // since it will promote later operands of the terminator being erased // first, reducing the quadratic-ness. unsigned succ = succE - succI - 1; - Optional succOperands = branchOp.getSuccessorOperands(succ); + Optional succOperands = + branchOp.getMutableSuccessorOperands(succ); if (!succOperands) continue; Block *successor = terminator->getSuccessor(succ); @@ -288,7 +289,7 @@ // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; if (!liveMap.wasProvenLive(successor->getArgument(arg))) - branchOp.eraseSuccessorOperand(succ, arg); + succOperands->erase(arg); } } } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -167,13 +167,12 @@ // TestBranchOp //===----------------------------------------------------------------------===// -Optional TestBranchOp::getSuccessorOperands(unsigned index) { +Optional +TestBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return targetOperandsMutable(); } -bool TestBranchOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -146,7 +146,7 @@ StringRef interfaceName, StringRef interfaceTraitsName) { os << " template \n " - << llvm::formatv("struct Trait : public OpInterface<{0}," + << llvm::formatv("struct {0}Trait : public OpInterface<{0}," " detail::{1}>::Trait {{\n", interfaceName, interfaceTraitsName); @@ -171,13 +171,17 @@ tblgen::FmtContext traitCtx; traitCtx.withOp("op"); if (auto verify = interface.getVerify()) { - os << " static LogicalResult verifyTrait(Operation* op) {\n" + os << " static LogicalResult verifyTrait(Operation* op) {\n" << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n"; } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) os << extraTraitDecls << "\n"; os << " };\n"; + + // Emit a utility using directive for the trait class. + os << " template \n " + << llvm::formatv("using Trait = {0}Trait;\n", interfaceName); } static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {