diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -70,40 +70,20 @@ LLVM_Op, Results<(outs)>, LLVM_TwoBuilders; +// Opaque builder used for terminator operations that contain successors. +def LLVM_TerminatorPassthroughOpBuilder : OpBuilder< + "Builder *, OperationState &result, ValueRange operands, " + "SuccessorRange destinations, ArrayRef attributes = {}", + [{ + result.addOperands(operands); + result.addSuccessors(destinations); + result.addAttributes(attributes); + }]>; + // Base class for LLVM terminator operations. All terminator operations have // zero results and an optional list of successors. class LLVM_TerminatorOp traits = []> : - LLVM_Op { - let builders = [ - OpBuilder< - "Builder *, OperationState &result, " - "ValueRange properOperands, " - "ArrayRef destinations, " - "ArrayRef operands, " - "ArrayRef attributes = {}", - [{ - result.addOperands(properOperands); - for (auto kvp : llvm::zip(destinations, operands)) { - result.addSuccessor(std::get<0>(kvp), std::get<1>(kvp)); - } - for (auto namedAttr : attributes) { - result.addAttribute(namedAttr.first, namedAttr.second); - } - }] - >, - OpBuilder< - "Builder *builder, OperationState &result, " - "ValueRange properOperands, " - "ArrayRef destinations, " - "ArrayRef attributes = {}", - [{ - SmallVector operands(destinations.size(), {}); - build(builder, result, properOperands, - destinations, operands, attributes); - }] - >, - ]; -} + LLVM_Op; // Class for arithmetic binary operations. class LLVM_ArithmeticOp, Terminator + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + Terminator ]>, Arguments<(ins OptionalAttr:$callee, - Variadic)>, + Variadic:$operands, + Variadic:$normalDestOperands, + Variadic:$unwindDestOperands)>, Results<(outs Variadic)> { let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); let builders = [OpBuilder< + "Builder *b, OperationState &result, ArrayRef tys, " + "FlatSymbolRefAttr callee, ValueRange ops, Block* normal, " + "ValueRange normalOps, Block* unwind, ValueRange unwindOps", + [{ + result.addAttribute("callee", callee); + build(b, result, tys, ops, normal, normalOps, unwind, unwindOps); + }]>, + OpBuilder< "Builder *b, OperationState &result, ArrayRef tys, " "ValueRange ops, Block* normal, " "ValueRange normalOps, Block* unwind, ValueRange unwindOps", [{ - result.addTypes(tys); - result.addOperands(ops); - result.addSuccessor(normal, normalOps); - result.addSuccessor(unwind, unwindOps); + build(b, result, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps, + unwindOps, normal, unwind); }]>]; let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInvokeOp(parser, result); }]; @@ -463,14 +453,38 @@ // Terminators. def LLVM_BrOp : LLVM_TerminatorOp<"br", [DeclareOpInterfaceMethods]> { + let arguments = (ins Variadic:$destOperands); let successors = (successor AnySuccessor:$dest); - let assemblyFormat = "$dest attr-dict"; + let assemblyFormat = [{ + $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict + }]; + let builders = [LLVM_TerminatorPassthroughOpBuilder]; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", - [DeclareOpInterfaceMethods]> { - let arguments = (ins LLVMI1:$condition); + [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let arguments = (ins LLVMI1:$condition, + Variadic:$trueDestOperands, + Variadic:$falseDestOperands); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); - let assemblyFormat = "$condition `,` successors attr-dict"; + let assemblyFormat = [{ + $condition `,` + $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,` + $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)? + attr-dict + }]; + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value condition," + "Block *trueDest, ValueRange trueOperands," + "Block *falseDest, ValueRange falseOperands", [{ + build(builder, result, condition, trueOperands, falseOperands, trueDest, + falseDest); + }]>, OpBuilder< + "Builder *builder, OperationState &result, Value condition," + "Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{ + build(builder, result, condition, trueDest, ValueRange(), falseDest, + falseOperands); + }]>, LLVM_TerminatorPassthroughOpBuilder]; } def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>, Arguments<(ins Variadic:$args)> { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -20,7 +20,7 @@ // ----- -def SPV_BranchOp : SPV_Op<"Branch",[ +def SPV_BranchOp : SPV_Op<"Branch", [ DeclareOpInterfaceMethods, InFunctionScope, Terminator]> { let summary = "Unconditional branch to target block."; @@ -44,7 +44,7 @@ ``` }]; - let arguments = (ins); + let arguments = (ins Variadic:$targetOperands); let results = (outs); @@ -56,7 +56,8 @@ OpBuilder< "Builder *, OperationState &state, " "Block *successor, ValueRange arguments = {}", [{ - state.addSuccessor(successor, arguments); + state.addSuccessors(successor); + state.addOperands(arguments); }] > ]; @@ -73,14 +74,16 @@ let autogenSerialization = 0; - let assemblyFormat = "successors attr-dict"; + let assemblyFormat = [{ + $target (`(` $targetOperands^ `:` type($targetOperands) `)`)? attr-dict + }]; } // ----- def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [ - DeclareOpInterfaceMethods, InFunctionScope, - Terminator]> { + AttrSizedOperandSegments, DeclareOpInterfaceMethods, + InFunctionScope, Terminator]> { let summary = [{ If Condition is true, branch to true block, otherwise branch to false block. @@ -121,13 +124,15 @@ let arguments = (ins SPV_Bool:$condition, + Variadic:$trueTargetOperands, + Variadic:$falseTargetOperands, OptionalAttr:$branch_weights ); let results = (outs); let successors = (successor AnySuccessor:$trueTarget, - AnySuccessor:$falseTarget); + AnySuccessor:$falseTarget); let builders = [ OpBuilder< @@ -136,21 +141,18 @@ "Block *falseBlock, ValueRange falseArguments, " "Optional> weights = {}", [{ - state.addOperands(condition); - state.addSuccessor(trueBlock, trueArguments); - state.addSuccessor(falseBlock, falseArguments); + ArrayAttr weightsAttr; if (weights) { - auto attr = + weightsAttr = builder->getI32ArrayAttr({static_cast(weights->first), static_cast(weights->second)}); - state.addAttribute("branch_weights", attr); } + build(builder, state, condition, trueArguments, falseArguments, + weightsAttr, trueBlock, falseBlock); }] > ]; - let skipDefaultBuilders = 1; - let autogenSerialization = 0; let extraClassDeclaration = [{ @@ -165,34 +167,22 @@ /// Returns the number of arguments to the true target block. unsigned getNumTrueBlockArguments() { - return getNumSuccessorOperands(kTrueIndex); + return trueTargetOperands().size(); } /// Returns the number of arguments to the false target block. unsigned getNumFalseBlockArguments() { - return getNumSuccessorOperands(kFalseIndex); + return falseTargetOperands().size(); } // Iterator and range support for true target block arguments. - operand_iterator true_block_argument_begin() { - return operand_begin() + getTrueBlockArgumentIndex(); - } - operand_iterator true_block_argument_end() { - return true_block_argument_begin() + getNumTrueBlockArguments(); - } operand_range getTrueBlockArguments() { - return {true_block_argument_begin(), true_block_argument_end()}; + return trueTargetOperands(); } // Iterator and range support for false target block arguments. - operand_iterator false_block_argument_begin() { - return true_block_argument_end(); - } - operand_iterator false_block_argument_end() { - return false_block_argument_begin() + getNumFalseBlockArguments(); - } operand_range getFalseBlockArguments() { - return {false_block_argument_begin(), false_block_argument_end()}; + return falseTargetOperands(); } private: diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -347,10 +347,13 @@ ^bb3(%3: tensor<*xf32>): }]; + let arguments = (ins Variadic:$destOperands); let successors = (successor AnySuccessor:$dest); - let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest", [{ - result.addSuccessor(dest, llvm::None); + let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest, " + "ValueRange destOperands = {}", [{ + result.addSuccessors(dest); + result.addOperands(destOperands); }]>]; // BranchOp is fully verified by traits. @@ -365,7 +368,9 @@ }]; let hasCanonicalizer = 1; - let assemblyFormat = "$dest attr-dict"; + let assemblyFormat = [{ + $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict + }]; } //===----------------------------------------------------------------------===// @@ -671,7 +676,8 @@ //===----------------------------------------------------------------------===// def CondBranchOp : Std_Op<"cond_br", - [DeclareOpInterfaceMethods, Terminator]> { + [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + Terminator]> { let summary = "conditional branch operation"; let description = [{ The "cond_br" operation represents a conditional branch operation in a @@ -688,9 +694,24 @@ ... }]; - let arguments = (ins I1:$condition); + let arguments = (ins I1:$condition, + Variadic:$trueDestOperands, + Variadic:$falseDestOperands); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value condition," + "Block *trueDest, ValueRange trueOperands," + "Block *falseDest, ValueRange falseOperands", [{ + build(builder, result, condition, trueOperands, falseOperands, trueDest, + falseDest); + }]>, OpBuilder< + "Builder *builder, OperationState &result, Value condition," + "Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{ + build(builder, result, condition, trueDest, ValueRange(), falseDest, + falseOperands); + }]>]; + // CondBranchOp is fully verified by traits. let verifier = ?; @@ -722,23 +743,13 @@ setOperand(getTrueDestOperandIndex() + idx, value); } - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } + operand_range getTrueOperands() { return trueDestOperands(); } - unsigned getNumTrueOperands() { - return getNumSuccessorOperands(trueIndex); - } + unsigned getNumTrueOperands() { return getTrueOperands().size(); } /// Erase the operand at 'index' from the true operand list. void eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); + eraseSuccessorOperand(trueIndex, index); } // Accessors for operands to the 'false' destination. @@ -751,21 +762,13 @@ setOperand(getFalseDestOperandIndex() + idx, value); } - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } + operand_range getFalseOperands() { return falseDestOperands(); } - unsigned getNumFalseOperands() { - return getNumSuccessorOperands(falseIndex); - } + unsigned getNumFalseOperands() { return getFalseOperands().size(); } /// Erase the operand at 'index' from the false operand list. void eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); + eraseSuccessorOperand(falseIndex, index); } private: @@ -779,7 +782,12 @@ }]; let hasCanonicalizer = 1; - let assemblyFormat = "$condition `,` successors attr-dict"; + let assemblyFormat = [{ + $condition `,` + $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,` + $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)? + attr-dict + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -668,10 +668,6 @@ static LogicalResult verifyTrait(Operation *op) { return impl::verifyIsTerminator(op); } - - unsigned getNumSuccessorOperands(unsigned index) { - return this->getOperation()->getNumSuccessorOperands(index); - } }; /// This class provides verification for ops that are known to have zero diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -62,9 +62,12 @@ /// provide a valid type for the attribute. virtual void printAttributeWithoutType(Attribute attr) = 0; - /// Print a successor, and use list, of a terminator operation given the - /// terminator and the successor index. - virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0; + /// Print the given successor. + virtual void printSuccessor(Block *successor) = 0; + + /// Print the successor and its operands. + virtual void printSuccessorAndUseList(Block *successor, + ValueRange succOperands) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary with their values. elidedAttrs allows the client to ignore @@ -120,8 +123,7 @@ /// Print the complete type of an operation in functional form. void printFunctionalType(Operation *op) { - printFunctionalType(op->getNonSuccessorOperands().getTypes(), - op->getResultTypes()); + printFunctionalType(op->getOperandTypes(), op->getResultTypes()); } /// Print the two given type ranges in a functional form. template @@ -188,6 +190,11 @@ return p << (value ? StringRef("true") : "false"); } +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { + p.printSuccessor(value); + return p; +} + template inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const ValueTypeRange &types) { @@ -574,15 +581,16 @@ // Successor Parsing //===--------------------------------------------------------------------===// + /// Parse a single operation successor. + virtual ParseResult parseSuccessor(Block *&dest) = 0; + + /// Parse an optional operation successor. + virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; + /// Parse a single operation successor and its operand list. virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl &operands) = 0; - /// Parse an optional operation successor and its operand list. - virtual OptionalParseResult - parseOptionalSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) = 0; - //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -374,7 +374,7 @@ } //===--------------------------------------------------------------------===// - // Terminators + // Successors //===--------------------------------------------------------------------===// MutableArrayRef getBlockOperands() { @@ -387,24 +387,8 @@ succ_iterator successor_end() { return getSuccessors().end(); } SuccessorRange getSuccessors() { return SuccessorRange(this); } - /// Return the operands of this operation that are *not* successor arguments. - operand_range getNonSuccessorOperands(); - - operand_range getSuccessorOperands(unsigned index); - - Value getSuccessorOperand(unsigned succIndex, unsigned opIndex) { - assert(!isKnownNonTerminator() && "only terminators may have successors"); - assert(opIndex < getNumSuccessorOperands(succIndex)); - return getOperand(getSuccessorOperandIndex(succIndex) + opIndex); - } - bool hasSuccessors() { return numSuccs != 0; } unsigned getNumSuccessors() { return numSuccs; } - unsigned getNumSuccessorOperands(unsigned index) { - assert(!isKnownNonTerminator() && "only terminators may have successors"); - assert(index < getNumSuccessors()); - return getBlockOperands()[index].numSuccessorOperands; - } Block *getSuccessor(unsigned index) { assert(index < getNumSuccessors()); @@ -412,37 +396,6 @@ } void setSuccessor(Block *block, unsigned index); - /// Erase a specific operand from the operand list of the successor at - /// 'index'. - void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) { - assert(succIndex < getNumSuccessors()); - assert(opIndex < getNumSuccessorOperands(succIndex)); - getOperandStorage().eraseOperand(getSuccessorOperandIndex(succIndex) + - opIndex); - --getBlockOperands()[succIndex].numSuccessorOperands; - } - - /// Get the index of the first operand of the successor at the provided - /// index. - unsigned getSuccessorOperandIndex(unsigned index); - - /// Return a pair (successorIndex, successorArgIndex) containing the index - /// of the successor that `operandIndex` belongs to and the index of the - /// argument to that successor that `operandIndex` refers to. - /// - /// If `operandIndex` is not a successor operand, None is returned. - Optional> - decomposeSuccessorOperandIndex(unsigned operandIndex); - - /// Returns the `BlockArgument` corresponding to operand `operandIndex` in - /// some successor, or None if `operandIndex` isn't a successor operand index. - Optional getSuccessorBlockArgument(unsigned operandIndex) { - auto decomposed = decomposeSuccessorOperandIndex(operandIndex); - if (!decomposed.hasValue()) - return None; - return getSuccessor(decomposed->first)->getArgument(decomposed->second); - } - //===--------------------------------------------------------------------===// // Accessors for various properties of operations //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -40,6 +40,7 @@ class Region; class ResultRange; class RewritePattern; +class SuccessorRange; class Type; class Value; class ValueRange; @@ -316,7 +317,12 @@ attributes.append(newAttributes.begin(), newAttributes.end()); } - void addSuccessor(Block *successor, ValueRange succOperands); + /// Add an array of successors. + void addSuccessors(ArrayRef newSuccessors) { + successors.append(newSuccessors.begin(), newSuccessors.end()); + } + void addSuccessors(Block *successor) { successors.push_back(successor); } + void addSuccessors(SuccessorRange newSuccessors); /// Create a region that should be attached to the operation. These regions /// can be filled in immediately without waiting for Operation to be diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -295,13 +295,6 @@ /// Return which operand this is in the operand list of the User. unsigned getOperandNumber(); - -private: - /// The number of OpOperands that correspond with this block operand. - unsigned numSuccessorOperands = 0; - - /// Allow access to 'numSuccessorOperands'. - friend Operation; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -227,44 +227,13 @@ /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of rewritten values /// that are passed to this operation, `rewriter` can be used to emit the new - /// operations. This function must be reimplemented if the - /// ConversionPattern ever needs to replace an operation that does not - /// have successors. This function should not fail. If some specific cases of + /// operations. This function should not fail. If some specific cases of /// the operation are not supported, these cases should not be matched. virtual void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } - /// Hook for derived classes to implement rewriting. `op` is the (first) - /// operation matched by the pattern, `properOperands` is a list of rewritten - /// values that are passed to the operation itself, `destinations` is a list - /// of (potentially rewritten) successor blocks, `operands` is a list of lists - /// of rewritten values passed to each of the successors, co-indexed with - /// `destinations`, `rewriter` can be used to emit the new operations. It must - /// be reimplemented if the ConversionPattern ever needs to replace a - /// terminator operation that has successors. This function should not fail - /// the pass. If some specific cases of the operation are not supported, - /// these cases should not be matched. - virtual void rewrite(Operation *op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("unimplemented rewrite for terminators"); - } - - /// Hook for derived classes to implement combined matching and rewriting. - virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - ConversionPatternRewriter &rewriter) const { - if (!match(op)) - return matchFailure(); - rewrite(op, properOperands, destinations, operands, rewriter); - return matchSuccess(); - } - /// Hook for derived classes to implement combined matching and rewriting. virtual PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -297,21 +266,6 @@ ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), operands, rewriter); } - void rewrite(Operation *op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), properOperands, destinations, operands, - rewriter); - } - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), properOperands, destinations, - operands, rewriter); - } PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { @@ -328,24 +282,6 @@ llvm_unreachable("must override matchAndRewrite or a rewrite method"); } - virtual void rewrite(SourceOp op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("unimplemented rewrite for terminators"); - } - - virtual PatternMatchResult - matchAndRewrite(SourceOp op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - ConversionPatternRewriter &rewriter) const { - if (!match(op)) - return matchFailure(); - rewrite(op, properOperands, destinations, operands, rewriter); - return matchSuccess(); - } - virtual PatternMatchResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -90,8 +90,7 @@ // Add branch before inserted body, into body. block = block->getNextNode(); - rewriter.create(loc, ArrayRef{}, - llvm::makeArrayRef(block), ValueRange()); + rewriter.create(loc, ValueRange(), block); // Replace all gpu.yield ops with branch out of body. for (; block != split; block = block->getNextNode()) { @@ -100,8 +99,7 @@ continue; rewriter.setInsertionPointToEnd(block); rewriter.replaceOpWithNewOp( - terminator, ArrayRef{}, llvm::makeArrayRef(split), - ValueRange(terminator->getOperand(0))); + terminator, terminator->getOperand(0), split); } // Return accumulator result. @@ -254,13 +252,10 @@ Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin()); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, llvm::makeArrayRef(condition), - ArrayRef{thenBlock, elseBlock}); + rewriter.create(loc, condition, thenBlock, elseBlock); auto addBranch = [&](ValueRange operands) { - rewriter.create(loc, ArrayRef{}, - llvm::makeArrayRef(continueBlock), - llvm::makeArrayRef(operands)); + rewriter.create(loc, operands, continueBlock); }; rewriter.setInsertionPointToStart(thenBlock); @@ -645,8 +640,7 @@ PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands, - ArrayRef()); + rewriter.replaceOpWithNewOp(op, operands); return matchSuccess(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -2185,13 +2185,10 @@ using Super = OneToOneLLVMTerminatorLowering; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - SmallVector operandRanges(operands.begin(), operands.end()); - rewriter.replaceOpWithNewOp(op, properOperands, destinations, - operandRanges, op->getAttrs()); + rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), + op->getAttrs()); return this->matchSuccess(); } }; @@ -2213,13 +2210,12 @@ // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp( - op, ArrayRef(), ArrayRef(), op->getAttrs()); + op, ArrayRef(), ArrayRef(), op->getAttrs()); return matchSuccess(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( - op, ArrayRef(operands.front()), ArrayRef(), - op->getAttrs()); + op, ArrayRef(), operands.front(), op->getAttrs()); return matchSuccess(); } @@ -2234,8 +2230,8 @@ op->getLoc(), packedType, packed, operands[i], rewriter.getI64ArrayAttr(i)); } - rewriter.replaceOpWithNewOp( - op, llvm::makeArrayRef(packed), ArrayRef(), op->getAttrs()); + rewriter.replaceOpWithNewOp(op, ArrayRef(), packed, + op->getAttrs()); return matchSuccess(); } }; @@ -2742,10 +2738,8 @@ auto memRefType = atomicOp.getMemRefType(); auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter, getModule()); - auto init = rewriter.create(loc, dataPtr); - std::array brRegionOperands{init}; - std::array brOperands{brRegionOperands}; - rewriter.create(loc, ArrayRef{}, loopBlock, brOperands); + Value init = rewriter.create(loc, dataPtr); + rewriter.create(loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); @@ -2768,19 +2762,14 @@ loc, pairType, dataPtr, loopArgument, select, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. - auto newLoaded = rewriter.create( + Value newLoaded = rewriter.create( loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); - auto ok = rewriter.create( + Value ok = rewriter.create( loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); // Conditionally branch to the end or back to the loop depending on %ok. - std::array condBrProperOperands{ok}; - std::array condBrDestinations{endBlock, loopBlock}; - std::array condBrRegionOperands{newLoaded}; - std::array condBrOperands{ArrayRef{}, - condBrRegionOperands}; - rewriter.create(loc, condBrProperOperands, - condBrDestinations, condBrOperands); + rewriter.create(loc, ok, endBlock, ArrayRef(), + loopBlock, newLoaded); // The 'result' of the atomic_rmw op is the newly loaded value. rewriter.replaceOp(op, {newLoaded}); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -250,9 +250,9 @@ return success(); } -///===----------------------------------------------------------------------===// +///===---------------------------------------------------------------------===// /// LLVM::InvokeOp -///===----------------------------------------------------------------------===// +///===---------------------------------------------------------------------===// Optional InvokeOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); @@ -278,7 +278,7 @@ return success(); } -static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) { +static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { auto callee = op.callee(); bool isDirect = callee.hasValue(); @@ -292,17 +292,16 @@ p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; p << " to "; - p.printSuccessorAndUseList(op.getOperation(), 0); + p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands()); p << " unwind "; - p.printSuccessorAndUseList(op.getOperation(), 1); - - p.printOptionalAttrDict(op.getAttrs(), {"callee"}); - - SmallVector argTypes( - llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); + p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands()); - p << " : " - << FunctionType::get(argTypes, op.getResultTypes(), op.getContext()); + p.printOptionalAttrDict(op.getAttrs(), + {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); + p << " : "; + p.printFunctionalType( + llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), + op.getResultTypes()); } /// ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` @@ -316,6 +315,7 @@ llvm::SMLoc trailingTypeLoc; Block *normalDest, *unwindDest; SmallVector normalOperands, unwindOperands; + Builder &builder = parser.getBuilder(); // Parse an operand list that will, in practice, contain 0 or 1 operand. In // case of an indirect call, there will be 1 operand before `(`. In case of a @@ -351,7 +351,6 @@ return parser.emitError(trailingTypeLoc, "expected function with 0 or 1 result"); - Builder &builder = parser.getBuilder(); auto *llvmDialect = builder.getContext()->getRegisteredDialect(); LLVM::LLVMType llvmResultType; @@ -390,8 +389,15 @@ result.addTypes(llvmResultType); } - result.addSuccessor(normalDest, normalOperands); - result.addSuccessor(unwindDest, unwindOperands); + result.addSuccessors({normalDest, unwindDest}); + result.addOperands(normalOperands); + result.addOperands(unwindOperands); + + result.addAttribute( + InvokeOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({static_cast(operands.size()), + static_cast(normalOperands.size()), + static_cast(unwindOperands.size())})); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -971,7 +971,6 @@ auto &builder = parser.getBuilder(); OpAsmParser::OperandType condInfo; Block *dest; - SmallVector destOperands; // Parse the condition. Type boolTy = builder.getI1Type(); @@ -996,17 +995,24 @@ } // Parse the true branch. + SmallVector trueOperands; if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) + parser.parseSuccessorAndUseList(dest, trueOperands)) return failure(); - state.addSuccessor(dest, destOperands); + state.addSuccessors(dest); + state.addOperands(trueOperands); // Parse the false branch. - destOperands.clear(); + SmallVector falseOperands; if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) + parser.parseSuccessorAndUseList(dest, falseOperands)) return failure(); - state.addSuccessor(dest, destOperands); + state.addSuccessors(dest); + state.addOperands(falseOperands); + state.addAttribute( + spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({1, static_cast(trueOperands.size()), + static_cast(falseOperands.size())})); return success(); } @@ -1024,11 +1030,11 @@ } printer << ", "; - printer.printSuccessorAndUseList(branchOp.getOperation(), - spirv::BranchConditionalOp::kTrueIndex); + printer.printSuccessorAndUseList(branchOp.getTrueBlock(), + branchOp.getTrueBlockArguments()); printer << ", "; - printer.printSuccessorAndUseList(branchOp.getOperation(), - spirv::BranchConditionalOp::kFalseIndex); + printer.printSuccessorAndUseList(branchOp.getFalseBlock(), + branchOp.getFalseBlockArguments()); } static LogicalResult verify(spirv::BranchConditionalOp branchOp) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1964,9 +1964,13 @@ /*withKeyword=*/true); } + /// Print the given successor. + void printSuccessor(Block *successor) override; + /// Print an operation successor with the operands used for the block /// arguments. - void printSuccessorAndUseList(Operation *term, unsigned index) override; + void printSuccessorAndUseList(Block *successor, + ValueRange succOperands) override; /// Print the given region. void printRegion(Region ®ion, bool printEntryBlockArgs, @@ -2062,23 +2066,14 @@ os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; - - // Get the list of operands that are not successor operands. - unsigned totalNumSuccessorOperands = 0; - unsigned numSuccessors = op->getNumSuccessors(); - for (unsigned i = 0; i < numSuccessors; ++i) - totalNumSuccessorOperands += op->getNumSuccessorOperands(i); - unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - interleaveComma(op->getOperands().take_front(numProperOperands), - [&](Value value) { printValueID(value); }); - + interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); os << ')'; // For terminators, print the list of successors and their operands. - if (numSuccessors != 0) { + if (op->getNumSuccessors() != 0) { os << '['; - interleaveComma(llvm::seq(0, numSuccessors), - [&](unsigned i) { printSuccessorAndUseList(op, i); }); + interleaveComma(op->getSuccessors(), + [&](Block *successor) { printBlockName(successor); }); os << ']'; } @@ -2167,12 +2162,14 @@ state->getSSANameState().printValueID(value, printResultNo, os); } -void OperationPrinter::printSuccessorAndUseList(Operation *term, - unsigned index) { - printBlockName(term->getSuccessor(index)); +void OperationPrinter::printSuccessor(Block *successor) { + printBlockName(successor); +} - auto succOperands = term->getSuccessorOperands(index); - if (succOperands.begin() == succOperands.end()) +void OperationPrinter::printSuccessorAndUseList(Block *successor, + ValueRange succOperands) { + printBlockName(successor); + if (succOperands.empty()) return; os << '('; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -111,14 +111,10 @@ NamedAttributeList attributes, ArrayRef successors, unsigned numRegions, bool resizableOperandList) { - unsigned numSuccessors = successors.size(); - // We only need to allocate additional memory for a subset of results. unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); - - // Input operands are nullptr-separated for each successor, the null operands - // aren't actually stored. - unsigned numOperands = operands.size() - numSuccessors; + unsigned numSuccessors = successors.size(); + unsigned numOperands = operands.size(); // Compute the byte size for the operation and the operand storage. auto byteSize = totalSizeToAllocgetRegion(i)) Region(op); - // Initialize the results and operands. + // Initialize the operands. new (&op->getOperandStorage()) detail::OperandStorage(numOperands, resizableOperandList); auto opOperands = op->getOpOperands(); + for (unsigned i = 0; i != numOperands; ++i) + new (&opOperands[i]) OpOperand(op, operands[i]); - // Initialize normal operands. - unsigned operandIt = 0, operandE = operands.size(); - unsigned nextOperand = 0; - for (; operandIt != operandE; ++operandIt) { - // Null operands are used as sentinels between successor operand lists. If - // we encounter one here, break and handle the successor operands lists - // separately below. - if (!operands[operandIt]) - break; - new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]); - } - - unsigned currentSuccNum = 0; - if (operandIt == operandE) { - // Verify that the amount of sentinel operands is equivalent to the number - // of successors. - assert(currentSuccNum == numSuccessors); - return op; - } - - assert(!op->isKnownNonTerminator() && - "Unexpected nullptr in operand list when creating non-terminator."); - auto instBlockOperands = op->getBlockOperands(); - unsigned *succOperandCount = nullptr; - - for (; operandIt != operandE; ++operandIt) { - // If we encounter a sentinel branch to the next operand update the count - // variable. - if (!operands[operandIt]) { - assert(currentSuccNum < numSuccessors); - - new (&instBlockOperands[currentSuccNum]) - BlockOperand(op, successors[currentSuccNum]); - succOperandCount = - &instBlockOperands[currentSuccNum].numSuccessorOperands; - ++currentSuccNum; - continue; - } - new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]); - ++(*succOperandCount); - } - - // Verify that the amount of sentinel operands is equivalent to the number of - // successors. - assert(currentSuccNum == numSuccessors); + // Initialize the successors. + auto blockOperands = op->getBlockOperands(); + for (unsigned i = 0; i != numSuccessors; ++i) + new (&blockOperands[i]) BlockOperand(op, successors[i]); return op; } @@ -564,49 +521,6 @@ getBlockOperands()[index].set(block); } -auto Operation::getNonSuccessorOperands() -> operand_range { - return getOperands().take_front(hasSuccessors() ? getSuccessorOperandIndex(0) - : getNumOperands()); -} - -/// Get the index of the first operand of the successor at the provided -/// index. -unsigned Operation::getSuccessorOperandIndex(unsigned index) { - assert(!isKnownNonTerminator() && "only terminators may have successors"); - assert(index < getNumSuccessors()); - - // Count the number of operands for each of the successors after, and - // including, the one at 'index'. This is based upon the assumption that all - // non successor operands are placed at the beginning of the operand list. - auto blockOperands = getBlockOperands().drop_front(index); - unsigned postSuccessorOpCount = - std::accumulate(blockOperands.begin(), blockOperands.end(), 0u, - [](unsigned cur, const BlockOperand &operand) { - return cur + operand.numSuccessorOperands; - }); - return getNumOperands() - postSuccessorOpCount; -} - -Optional> -Operation::decomposeSuccessorOperandIndex(unsigned operandIndex) { - assert(!isKnownNonTerminator() && "only terminators may have successors"); - assert(operandIndex < getNumOperands()); - unsigned currentOperandIndex = getNumOperands(); - auto blockOperands = getBlockOperands(); - for (unsigned i = 0, e = getNumSuccessors(); i < e; i++) { - unsigned successorIndex = e - i - 1; - currentOperandIndex -= blockOperands[successorIndex].numSuccessorOperands; - if (currentOperandIndex <= operandIndex) - return std::make_pair(successorIndex, operandIndex - currentOperandIndex); - } - return None; -} - -auto Operation::getSuccessorOperands(unsigned index) -> operand_range { - unsigned succOperandIndex = getSuccessorOperandIndex(index); - return getOperands().slice(succOperandIndex, getNumSuccessorOperands(index)); -} - /// Attempt to fold this operation using the Op's registered foldHook. LogicalResult Operation::fold(ArrayRef operands, SmallVectorImpl &results) { @@ -645,39 +559,20 @@ SmallVector operands; SmallVector successors; - operands.reserve(getNumOperands() + getNumSuccessors()); + // Remap the operands. + operands.reserve(getNumOperands()); + for (auto opValue : getOperands()) + operands.push_back(mapper.lookupOrDefault(opValue)); - if (getNumSuccessors() == 0) { - // Non-branching operations can just add all the operands. - for (auto opValue : getOperands()) - operands.push_back(mapper.lookupOrDefault(opValue)); - } else { - // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = - getNumSuccessors() ? getSuccessorOperandIndex(0) : getNumOperands(); - auto opOperands = getOpOperands(); - - unsigned i = 0; - for (; i != firstSuccOperand; ++i) - operands.push_back(mapper.lookupOrDefault(opOperands[i].get())); - - successors.reserve(getNumSuccessors()); - for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) { - successors.push_back(mapper.lookupOrDefault(getSuccessor(succ))); - - // Add sentinel to delineate successor operands. - operands.push_back(nullptr); - - // Remap the successors operands. - for (auto operand : getSuccessorOperands(succ)) - operands.push_back(mapper.lookupOrDefault(operand)); - } - } + // Remap the successors. + successors.reserve(getNumSuccessors()); + for (Block *successor : getSuccessors()) + successors.push_back(mapper.lookupOrDefault(successor)); - unsigned numRegions = getNumRegions(); - auto *newOp = - Operation::create(getLoc(), getName(), getResultTypes(), operands, attrs, - successors, numRegions, hasResizableOperandsList()); + // Create the new operation. + auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(), + operands, attrs, successors, getNumRegions(), + hasResizableOperandsList()); // Remember the mapping of any results. for (unsigned i = 0, e = getNumResults(); i != e; ++i) diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -42,15 +42,11 @@ } void OperationState::addOperands(ValueRange newOperands) { - assert(successors.empty() && "Non successor operands should be added first."); operands.append(newOperands.begin(), newOperands.end()); } -void OperationState::addSuccessor(Block *successor, ValueRange succOperands) { - successors.push_back(successor); - // Insert a sentinel operand to mark a barrier between successor operands. - operands.push_back(nullptr); - operands.append(succOperands.begin(), succOperands.end()); +void OperationState::addSuccessors(SuccessorRange newSuccessors) { + successors.append(newSuccessors.begin(), newSuccessors.end()); } Region *OperationState::addRegion() { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3301,13 +3301,11 @@ /// Parse an operation instance. ParseResult parseOperation(); - /// Parse a single operation successor and its operand list. - ParseResult parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands); + /// Parse a single operation successor. + ParseResult parseSuccessor(Block *&dest); /// Parse a comma-separated list of operation successors in brackets. - ParseResult parseSuccessors(SmallVectorImpl &destinations, - SmallVectorImpl> &operands); + ParseResult parseSuccessors(SmallVectorImpl &destinations); /// Parse an operation instance that is in the generic form. Operation *parseGenericOperation(); @@ -3797,27 +3795,16 @@ return success(); } -/// Parse a single operation successor and its operand list. +/// Parse a single operation successor. /// -/// successor ::= block-id branch-use-list? -/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` +/// successor ::= block-id /// -ParseResult -OperationParser::parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) { +ParseResult OperationParser::parseSuccessor(Block *&dest) { // Verify branch is identifier and get the matching block. if (!getToken().is(Token::caret_identifier)) return emitError("expected block name"); dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); consumeToken(); - - // Handle optional arguments. - if (consumeIf(Token::l_paren) && - (parseOptionalSSAUseAndTypeList(operands) || - parseToken(Token::r_paren, "expected ')' to close argument list"))) { - return failure(); - } - return success(); } @@ -3825,18 +3812,15 @@ /// /// successor-list ::= `[` successor (`,` successor )* `]` /// -ParseResult OperationParser::parseSuccessors( - SmallVectorImpl &destinations, - SmallVectorImpl> &operands) { +ParseResult +OperationParser::parseSuccessors(SmallVectorImpl &destinations) { if (parseToken(Token::l_square, "expected '['")) return failure(); - auto parseElt = [this, &destinations, &operands]() { + auto parseElt = [this, &destinations] { Block *dest; - SmallVector destOperands; - auto res = parseSuccessorAndUseList(dest, destOperands); + ParseResult res = parseSuccessor(dest); destinations.push_back(dest); - operands.push_back(destOperands); return res; }; return parseCommaSeparatedListUntil(Token::r_square, parseElt, @@ -3880,24 +3864,23 @@ // Parse the operand list. SmallVector operandInfos; - if (parseToken(Token::l_paren, "expected '(' to start operand list") || parseOptionalSSAUseList(operandInfos) || parseToken(Token::r_paren, "expected ')' to end operand list")) { return nullptr; } - // Parse the successor list but don't add successors to the result yet to - // avoid messing up with the argument order. - SmallVector successors; - SmallVector, 2> successorOperands; + // Parse the successor list. if (getToken().is(Token::l_square)) { // Check if the operation is a known terminator. const AbstractOperation *abstractOp = result.name.getAbstractOperation(); if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator)) return emitError("successors in non-terminator"), nullptr; - if (parseSuccessors(successors, successorOperands)) + + SmallVector successors; + if (parseSuccessors(successors)) return nullptr; + result.addSuccessors(successors); } // Parse the region list. @@ -3948,13 +3931,6 @@ return nullptr; } - // Add the successors, and their operands after the proper operands. - for (auto succ : llvm::zip(successors, successorOperands)) { - Block *successor = std::get<0>(succ); - const SmallVector &operands = std::get<1>(succ); - result.addSuccessor(successor, operands); - } - // Parse a location if one is present. if (parseOptionalTrailingLocation(result.location)) return nullptr; @@ -4421,20 +4397,31 @@ // Successor Parsing //===--------------------------------------------------------------------===// - /// Parse a single operation successor and its operand list. - ParseResult - parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) override { - return parser.parseSuccessorAndUseList(dest, operands); + /// Parse a single operation successor. + ParseResult parseSuccessor(Block *&dest) override { + return parser.parseSuccessor(dest); } /// Parse an optional operation successor and its operand list. - OptionalParseResult - parseOptionalSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) override { + OptionalParseResult parseOptionalSuccessor(Block *&dest) override { if (parser.getToken().isNot(Token::caret_identifier)) return llvm::None; - return parseSuccessorAndUseList(dest, operands); + return parseSuccessor(dest); + } + + /// Parse a single operation successor and its operand list. + ParseResult + parseSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands) override { + if (parseSuccessor(dest)) + return failure(); + + // Handle optional arguments. + if (succeeded(parseOptionalLParen()) && + (parser.parseOptionalSSAUseAndTypeList(operands) || parseRParen())) { + return failure(); + } + return success(); } //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -634,21 +634,29 @@ auto *brInst = cast(inst); OperationState state(loc, brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); - SmallVector ops; if (brInst->isConditional()) { Value condition = processValue(brInst->getCondition()); if (!condition) return failure(); - ops.push_back(condition); + state.addOperands(condition); } - state.addOperands(ops); - SmallVector succs; - for (auto *succ : llvm::reverse(brInst->successors())) { + + std::array operandSegmentSizes = {1, 0, 0}; + for (int i : llvm::seq(0, brInst->getNumSuccessors())) { + auto *succ = brInst->getSuccessor(i); SmallVector blockArguments; if (failed(processBranchArgs(brInst, succ, blockArguments))) return failure(); - state.addSuccessor(blocks[succ], blockArguments); + state.addSuccessors(blocks[succ]); + state.addOperands(blockArguments); + operandSegmentSizes[i + 1] = blockArguments.size(); } + + if (brInst->isConditional()) { + state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(), + b.getI32VectorAttr(operandSegmentSizes)); + } + b.createOperation(state); return success(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1005,33 +1005,7 @@ SmallVector operands; auto &dialectRewriter = static_cast(rewriter); dialectRewriter.getImpl().remapValues(op->getOperands(), operands); - - // If this operation has no successors, invoke the rewrite directly. - if (op->getNumSuccessors() == 0) - return matchAndRewrite(op, operands, dialectRewriter); - - // Otherwise, we need to remap the successors. - SmallVector destinations; - destinations.reserve(op->getNumSuccessors()); - - SmallVector, 2> operandsPerDestination; - unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); - for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { - destinations.push_back(op->getSuccessor(i)); - - // Lookup the successors operands. - unsigned n = op->getNumSuccessorOperands(i); - operandsPerDestination.push_back( - llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n)); - seen += n; - } - - // Rewrite the operation. - return matchAndRewrite( - op, - llvm::makeArrayRef(operands.data(), - operands.data() + firstSuccessorOperand), - destinations, operandsPerDestination, dialectRewriter); + return matchAndRewrite(op, operands, dialectRewriter); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/RegionUtils.h" +#include "mlir/Analysis/ControlFlowInterfaces.h" #include "mlir/IR/Block.h" #include "mlir/IR/Operation.h" #include "mlir/IR/RegionGraphTraits.h" @@ -172,8 +173,9 @@ // node, rather than to the terminator op itself, a terminator op can't e.g. // "print" the value of a successor operand. if (owner->isKnownTerminator()) { - if (auto arg = owner->getSuccessorBlockArgument(operandIndex)) - return !liveMap.wasProvenLive(*arg); + if (BranchOpInterface branchInterface = dyn_cast(owner)) + if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex)) + return !liveMap.wasProvenLive(*arg); return false; } return false; @@ -200,6 +202,29 @@ } static void propagateLiveness(Region ®ion, LiveMap &liveMap); + +static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { + // Terminators are always live. + liveMap.setProvedLive(op); + + // Check to see if we can reason about the successor operands and mutate them. + BranchOpInterface branchInterface = dyn_cast(op); + if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) { + for (Block *successor : op->getSuccessors()) + for (BlockArgument arg : successor->getArguments()) + liveMap.setProvedLive(arg); + return; + } + + // If we can't reason about the operands to a successor, conservatively mark + // all arguments as live. + for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { + if (!branchInterface.getSuccessorOperands(i)) + for (BlockArgument arg : op->getSuccessor(i)->getArguments()) + liveMap.setProvedLive(arg); + } +} + static void propagateLiveness(Operation *op, LiveMap &liveMap) { // All Value's are either a block argument or an op result. // We call processValue on those cases. @@ -208,6 +233,10 @@ for (Region ®ion : op->getRegions()) propagateLiveness(region, liveMap); + // Process terminator operations. + if (op->isKnownTerminator()) + return propagateTerminatorLiveness(op, liveMap); + // Process the op itself. if (isOpIntrinsicallyLive(op)) { liveMap.setProvedLive(op); @@ -238,6 +267,10 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap) { + BranchOpInterface branchOp = dyn_cast(terminator); + if (!branchOp) + return; + for (unsigned succI = 0, succE = terminator->getNumSuccessors(); succI < succE; succI++) { // Iterating successors in reverse is not strictly needed, since we @@ -245,15 +278,17 @@ // since it will promote later operands of the terminator being erased // first, reducing the quadratic-ness. unsigned succ = succE - succI - 1; - for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ); - argI < argE; argI++) { + Optional succOperands = branchOp.getSuccessorOperands(succ); + if (!succOperands) + continue; + Block *successor = terminator->getSuccessor(succ); + + for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) { // Iterating args in reverse is needed for correctness, to avoid // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; - Value value = terminator->getSuccessor(succ)->getArgument(arg); - if (!liveMap.wasProvenLive(value)) { - terminator->eraseSuccessorOperand(succ, arg); - } + if (!liveMap.wasProvenLive(successor->getArgument(arg))) + branchOp.eraseSuccessorOperand(succ, arg); } } } diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -24,7 +24,7 @@ // ----- func @missing_accessor() -> () { - // expected-error @+1 {{requires 1 successor but found 0}} + // expected-error @+2 {{expected block name}} spv.Branch } @@ -117,7 +117,7 @@ func @wrong_accessor_count() -> () { %true = spv.constant true // expected-error @+1 {{requires 2 successors but found 1}} - "spv.BranchConditional"(%true)[^one] : (i1) -> () + "spv.BranchConditional"(%true)[^one] {operand_segment_sizes = dense<[1, 0, 0]>: vector<3xi32>} : (i1) -> () ^one: spv.Return ^two: @@ -129,7 +129,8 @@ func @wrong_number_of_weights() -> () { %true = spv.constant true // expected-error @+1 {{must have exactly two branch weights}} - "spv.BranchConditional"(%true)[^one, ^two] {branch_weights = [1 : i32, 2 : i32, 3 : i32]} : (i1) -> () + "spv.BranchConditional"(%true)[^one, ^two] {branch_weights = [1 : i32, 2 : i32, 3 : i32], + operand_segment_sizes = dense<[1, 0, 0]>: vector<3xi32>} : (i1) -> () ^one: spv.Return ^two: diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -153,7 +153,7 @@ func @block_arg_no_close_paren() { ^bb42: - br ^bb2( // expected-error@+1 {{expected ')' to close argument list}} + br ^bb2( // expected-error@+1 {{expected ':'}} return } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -439,11 +439,11 @@ func @verbose_terminators() -> (i1, i17) { %0:2 = "foo"() : () -> (i1, i17) // CHECK: br ^bb1(%{{.*}}#0, %{{.*}}#1 : i1, i17) - "std.br"()[^bb1(%0#0, %0#1 : i1, i17)] : () -> () + "std.br"(%0#0, %0#1)[^bb1] : (i1, i17) -> () ^bb1(%x : i1, %y : i17): // CHECK: cond_br %{{.*}}, ^bb2(%{{.*}} : i17), ^bb3(%{{.*}}, %{{.*}} : i1, i17) - "std.cond_br"(%x)[^bb2(%y : i17), ^bb3(%x, %y : i1, i17)] : (i1) -> () + "std.cond_br"(%x, %y, %x, %y) [^bb2, ^bb3] {operand_segment_sizes = dense<[1, 1, 2]>: vector<3xi32>} : (i1, i17, i1, i17) -> () ^bb2(%a : i17): %true = constant 1 : i1 @@ -844,8 +844,8 @@ // CHECK-LABEL: func @unregistered_term func @unregistered_term(%arg0 : i1) -> i1 { - // CHECK-NEXT: "unregistered_br"()[^bb1(%{{.*}} : i1)] : () -> () - "unregistered_br"()[^bb1(%arg0 : i1)] : () -> () + // CHECK-NEXT: "unregistered_br"(%{{.*}})[^bb1] : (i1) -> () + "unregistered_br"(%arg0)[^bb1] : (i1) -> () ^bb1(%arg1 : i1): return %arg1 : i1 diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -20,7 +20,7 @@ // CHECK-NEXT: return func @f(%arg0: f32) { - "test.br"()[^succ(%arg0: f32)] : () -> () + "test.br"(%arg0)[^succ] : (f32) -> () ^succ(%0: f32): return } @@ -141,7 +141,7 @@ // Test case: Test the mechanics of deleting multiple block arguments. // CHECK: func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>) -// CHECK-NEXT: "test.br"()[^bb1(%arg1, %arg3 : tensor<2xf32>, tensor<4xf32>) +// CHECK-NEXT: "test.br"(%arg1, %arg3)[^bb1] : (tensor<2xf32>, tensor<4xf32>) // CHECK-NEXT: ^bb1([[VAL0:%.+]]: tensor<2xf32>, [[VAL1:%.+]]: tensor<4xf32>): // CHECK-NEXT: "foo.print"([[VAL0]]) // CHECK-NEXT: "foo.print"([[VAL1]]) @@ -154,7 +154,7 @@ %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>) { - "test.br"()[^succ(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>, tensor<5xf32>)] : () -> () + "test.br"(%arg0, %arg1, %arg2, %arg3, %arg4)[^succ] : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>, tensor<5xf32>) -> () ^succ(%t1: tensor<1xf32>, %t2: tensor<2xf32>, %t3: tensor<3xf32>, %t4: tensor<4xf32>, %t5: tensor<5xf32>): "foo.print"(%t2) : (tensor<2xf32>) -> () "foo.print"(%t4) : (tensor<4xf32>) -> () diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -462,6 +462,7 @@ def TestBranchOp : TEST_Op<"br", [DeclareOpInterfaceMethods, Terminator]> { + let arguments = (ins Variadic:$targetOperands); let successors = (successor AnySuccessor:$target); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -609,9 +609,8 @@ if (successor.name.empty()) continue; - // Generate the accessors for a variadic successor. + // Generate the accessors for a variadic successor list. if (successor.isVariadic()) { - // Generate the getter. auto &m = opClass.newMethod("SuccessorRange", successor.name); m.body() << formatv( " return {std::next(this->getOperation()->successor_begin(), {0}), " @@ -620,21 +619,8 @@ continue; } - // Generate the block getter. auto &m = opClass.newMethod("Block *", successor.name); m.body() << formatv(" return this->getOperation()->getSuccessor({0});", i); - - // Generate the all-operands getter. - auto &operandsMethod = opClass.newMethod( - "Operation::operand_range", (successor.name + "Operands").str()); - operandsMethod.body() << formatv( - " return this->getOperation()->getSuccessorOperands({0});", i); - - // Generate the individual-operand getter. - auto &operandMethod = opClass.newMethod( - "Value", (successor.name + "Operand").str(), "unsigned index"); - operandMethod.body() << formatv( - " return this->getOperation()->getSuccessorOperand({0}, index);", i); } } @@ -1044,14 +1030,9 @@ } /// Insert parameters for the block and operands for each successor. - const char *variadicSuccCode = - ", ArrayRef {0}, ArrayRef {0}Operands"; - const char *succCode = ", Block *{0}, ValueRange {0}Operands"; - for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { - if (namedSuccessor.isVariadic()) - paramList += llvm::formatv(variadicSuccCode, namedSuccessor.name).str(); - else - paramList += llvm::formatv(succCode, namedSuccessor.name).str(); + for (const NamedSuccessor &succ : op.getSuccessors()) { + paramList += (succ.isVariadic() ? ", ArrayRef " : ", Block *"); + paramList += succ.name; } } @@ -1123,14 +1104,7 @@ // Push all successors to the result. for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { - if (namedSuccessor.isVariadic()) { - body << formatv(" for (int i = 0, e = {1}.size(); i != e; ++i)\n" - " {0}.addSuccessor({1}[i], {1}Operands[i]);\n", - builderOpState, namedSuccessor.name); - continue; - } - - body << formatv(" {0}.addSuccessor({1}, {1}Operands);\n", builderOpState, + body << formatv(" {0}.addSuccessors({1});\n", builderOpState, namedSuccessor.name); } } @@ -1488,9 +1462,7 @@ int numVariadicOperands = op.getNumVariadicOperands(); // Add operand size trait. - // Note: Successor operands are also included in the operation's operand list, - // so we always need to use VariadicOperands in the presence of successors. - if (numVariadicOperands != 0 || op.getNumSuccessors()) { + if (numVariadicOperands != 0) { if (numOperands == numVariadicOperands) opClass.addTrait("OpTrait::VariadicOperands"); else diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -418,24 +418,20 @@ /// /// {0}: The name for the successor list. const char *successorListParserCode = R"( - SmallVector>, 2> {0}Successors; + SmallVector {0}Successors; { Block *succ; - SmallVector succOperands; - // Parse the first successor. - auto firstSucc = parser.parseOptionalSuccessorAndUseList(succ, - succOperands); + auto firstSucc = parser.parseOptionalSuccessor(succ); if (firstSucc.hasValue()) { if (failed(*firstSucc)) return failure(); - {0}Successors.emplace_back(succ, succOperands); + {0}Successors.emplace_back(succ); // Parse any trailing successors. while (succeeded(parser.parseOptionalComma())) { - succOperands.clear(); - if (parser.parseSuccessorAndUseList(succ, succOperands)) + if (parser.parseSuccessor(succ)) return failure(); - {0}Successors.emplace_back(succ, succOperands); + {0}Successors.emplace_back(succ); } } } @@ -446,19 +442,10 @@ /// {0}: The name of the successor. const char *successorParserCode = R"( Block *{0}Successor = nullptr; - SmallVector {0}Operands; - if (parser.parseSuccessorAndUseList({0}Successor, {0}Operands)) + if (parser.parseSuccessor({0}Successor)) return failure(); )"; -/// The code snippet used to resolve a list of parsed successors. -/// -/// {0}: The name of the successor list. -const char *resolveSuccessorListParserCode = R"( - for (auto &succAndArgs : {0}Successors) - result.addSuccessor(succAndArgs.first, succAndArgs.second); -)"; - /// Get the name used for the type list for the given type directive operand. /// 'isVariadic' is set to true if the operand has variadic types. static StringRef getTypeListName(Element *arg, bool &isVariadic) { @@ -802,19 +789,16 @@ bool hasAllSuccessors = llvm::any_of( elements, [](auto &elt) { return isa(elt.get()); }); if (hasAllSuccessors) { - body << llvm::formatv(resolveSuccessorListParserCode, "full"); + body << " result.addSuccessors(fullSuccessors);\n"; return; } // Otherwise, handle each successor individually. for (const NamedSuccessor &successor : op.getSuccessors()) { - if (successor.isVariadic()) { - body << llvm::formatv(resolveSuccessorListParserCode, successor.name); - continue; - } - - body << llvm::formatv(" result.addSuccessor({0}Successor, {0}Operands);\n", - successor.name); + if (successor.isVariadic()) + body << " result.addSuccessors(" << successor.name << "Successors);\n"; + else + body << " result.addSuccessors(" << successor.name << "Successor);\n"; } } @@ -957,28 +941,14 @@ body << " p << " << operand->getVar()->name << "();\n"; } else if (auto *successor = dyn_cast(element)) { const NamedSuccessor *var = successor->getVar(); - if (var->isVariadic()) { - body << " {\n" - << " auto succRange = " << var->name << "();\n" - << " auto opSuccBegin = getOperation()->successor_begin();\n" - << " int i = succRange.begin() - opSuccBegin;\n" - << " int e = i + succRange.size();\n" - << " interleaveComma(llvm::seq(i, e), p, [&](int i) {\n" - << " p.printSuccessorAndUseList(*this, i);\n" - << " });\n" - << " }\n"; - return; - } - - unsigned index = successor->getVar() - op.successor_begin(); - body << " p.printSuccessorAndUseList(*this, " << index << ");\n"; + if (var->isVariadic()) + body << " interleaveComma(" << var->name << "(), p);\n"; + else + body << " p << " << var->name << "();\n"; } else if (isa(element)) { body << " p << getOperation()->getOperands();\n"; } else if (isa(element)) { - body << " interleaveComma(llvm::seq(0, " - "getOperation()->getNumSuccessors()), p, [&](int i) {" - << " p.printSuccessorAndUseList(*this, i);" - << " });\n"; + body << " interleaveComma(getOperation()->getSuccessors(), p);\n"; } else if (auto *dir = dyn_cast(element)) { body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n";